C++ 如何改进矩阵乘法的内存管理

C++ 如何改进矩阵乘法的内存管理,c++,memory-management,matrix-multiplication,C++,Memory Management,Matrix Multiplication,我试图学习矩阵乘法,并遇到Strassen乘法与标准矩阵乘法的代码,所以我尝试实现它。然而,这段代码使用了太多的内存,以至于当矩阵足够大时,它会杀死程序。此外,由于它使用了太多的内存,因此处理时间更长 由于我不完全理解复杂的内存管理,所以我不太愿意过多地处理代码,我真的很想了解这个主题 在代码中有一个cut参数,发现at 320使它运行得更快,并且似乎在内存管理方面有所改进 编辑。我实现了一个复制构造函数、析构函数和一个跟踪内存使用情况的函数,它修复了内存泄漏问题,但Strassen矩阵在199

我试图学习矩阵乘法,并遇到Strassen乘法与标准矩阵乘法的代码,所以我尝试实现它。然而,这段代码使用了太多的内存,以至于当矩阵足够大时,它会杀死程序。此外,由于它使用了太多的内存,因此处理时间更长

由于我不完全理解复杂的内存管理,所以我不太愿意过多地处理代码,我真的很想了解这个主题

在代码中有一个cut参数,发现at 320使它运行得更快,并且似乎在内存管理方面有所改进

编辑。我实现了一个复制构造函数、析构函数和一个跟踪内存使用情况的函数,它修复了内存泄漏问题,但Strassen矩阵在1990维到2100维之间的时间上仍有很大的飞跃

矩阵h

#ifndef MATRIX_H

#define MATRIX_H



#include <vector>

using namespace std;



class matrix

{

public:

    matrix(int dim, bool random, bool strassen);

    matrix(const matrix& old_m);



    inline int dim() {

        return dim_;

    }

    inline int& operator()(unsigned row, unsigned col) {

        return data_[dim_ * row + col];

    }



    inline int operator()(unsigned row, unsigned col) const {

        return data_[dim_ * row + col];

    }



    void print();

    matrix operator+(matrix b);

    matrix operator-(matrix b);

    ~matrix();



private:

    int dim_;

    int* data_;

};



#endif

此外,当维度从1900变为2100时,它不应该在时间上有如此大的跳跃。

问题中需要包含代码。指向代码的链接是不可接受的,因为它们可能会中断并变得不可访问。看,我看了你的密码。没有必要再进一步,因为您未能实现。因此,
operator+
被破坏(加上您复制
矩阵
对象的代码的任何部分),您的问题可能与矩阵大小无关。因此,一旦你在这里发布代码,就要准备好对你的问题进行“重复”的结束。
#include <iostream>

#include <vector>

#include <stdlib.h>

#include <time.h>

#include "SAMmatrix.h"

using namespace std;



matrix::matrix(int dim, bool random, bool strassen) : dim_(dim) {

    if (strassen) {

        int dim2 = 2;

        while (dim2 < dim)

            dim2 *= 2;

        dim_ = dim2;

    }



    data_ = new int[dim_ * dim_];

    if (!random) return;



    for (int i = 0; i < dim_ * dim_; i++)

        data_[i] = rand() % 10;

}

matrix::matrix(const matrix& old_m){

    dim_ = old_m.dim_;  

    data_ = new int[dim_ * dim_];

    for (int i = 0; i < dim_ * dim_; i++)

        data_[i] = old_m.data_[i]; 

}



void matrix::print() {

    for (int i = 0; i < dim_; i++) {

        for (int j = 0; j < dim_; j++)

            cout << (*this)(i, j) << " ";

        cout << "\n";

    }

    cout << "\n";

}



matrix matrix::operator+(matrix b) {

    matrix c(dim_, false, false);

    for (int i = 0; i < dim_; i++)

        for (int j = 0; j < dim_; j++)

            c(i, j) = (*this)(i, j) + b(i, j);



    return c;

}



matrix matrix::operator-(matrix b) {

    matrix c(dim_, false, false);

    for (int i = 0; i < dim_; i++)

        for (int j = 0; j < dim_; j++)

            c(i, j) = (*this)(i, j) - b(i, j);



    return c;

}

matrix::~matrix()

{

    delete [] data_;

}
#include <iostream>

#include <stdlib.h>

#include <time.h>

#include <sys/time.h>

#include "SAMmatrix.h"

#include "stdlib.h"

#include "stdio.h"

#include "string.h"





typedef pair<matrix, long> result;



int cut = 64;



matrix mult_std(matrix a, matrix b) 

{

    matrix c(a.dim(), false, false);

    for (int i = 0; i < a.dim(); i++)

        for (int k = 0; k < a.dim(); k++)

            for (int j = 0; j < a.dim(); j++)

                c(i, j) += a(i, k) * b(k, j);



    return c;

}



matrix get_part(int pi, int pj, matrix m) 

{

    matrix p(m.dim() / 2, false, true);

    pi = pi * p.dim();

    pj = pj * p.dim();



    for (int i = 0; i < p.dim(); i++)

        for (int j = 0; j < p.dim(); j++)

            p(i, j) = m(i + pi, j + pj);



    return p;

}



void set_part(int pi, int pj, matrix* m, matrix p) 

{

    pi = pi * p.dim();

    pj = pj * p.dim();



    for (int i = 0; i < p.dim(); i++)

        for (int j = 0; j < p.dim(); j++)

            (*m)(i + pi, j + pj) = p(i, j);

}



matrix mult_strassen(matrix a, matrix b) 

{

    if (a.dim() <= cut)

        return mult_std(a, b);



    matrix a11 = get_part(0, 0, a);

    matrix a12 = get_part(0, 1, a);

    matrix a21 = get_part(1, 0, a);

    matrix a22 = get_part(1, 1, a);



    matrix b11 = get_part(0, 0, b);

    matrix b12 = get_part(0, 1, b);

    matrix b21 = get_part(1, 0, b);

    matrix b22 = get_part(1, 1, b);



    matrix m1 = mult_strassen(a11 + a22, b11 + b22);

    matrix m2 = mult_strassen(a21 + a22, b11);

    matrix m3 = mult_strassen(a11, b12 - b22);

    matrix m4 = mult_strassen(a22, b21 - b11);

    matrix m5 = mult_strassen(a11 + a12, b22);

    matrix m6 = mult_strassen(a21 - a11, b11 + b12);

    matrix m7 = mult_strassen(a12 - a22, b21 + b22);



    matrix c(a.dim(), false, true);

    set_part(0, 0, &c, m1 + m4 - m5 + m7);

    set_part(0, 1, &c, m3 + m5);

    set_part(1, 0, &c, m2 + m4);

    set_part(1, 1, &c, m1 - m2 + m3 + m6);



    return c;

}



pair<matrix, long> run(matrix(*f)(matrix, matrix), matrix a, matrix b) 

{

    struct timeval start, end;



    gettimeofday(&start, NULL);

    matrix c = f(a, b);

    gettimeofday(&end, NULL);

    long e = (end.tv_sec * 1000 + end.tv_usec / 1000);

    long s = (start.tv_sec * 1000 + start.tv_usec / 1000);



    return pair<matrix, long>(c, e - s);

}



int parseLine(char* line){ /* overflow*/

    // This assumes that a digit will be found and the line ends in " Kb".

    int i = strlen(line);

    const char* p = line;

    while (*p <'0' || *p > '9') p++;

    line[i-3] = '\0';

    i = atoi(p);

    return i;

}



int getValue(){ //Note: this value is in KB!

    FILE* file = fopen("/proc/self/status", "r");

    int result = -1;

    char line[128];



    while (fgets(line, 128, file) != NULL){

        if (strncmp(line, "VmSize:", 7) == 0){

            result = parseLine(line);

            break;

        }

    }

    fclose(file);

    return result;

}



int main() 

{

    /* test cut of for strassen

    /*

    for (cut = 2; cut <= 512; cut++) {

            matrix a(512, true, true);

            matrix b(512, true, true);

            result r = run(mult_strassen, a, b);

            cout << cut << " " << r.second << "\n";

    }

    */



    /* performance test: standard and strassen */

    /*1024 going up by 64*/

    for (int dim = 1500; dim <= 2300; dim += 200) 

    {

        double space = getValue() * .01;

        cout << "Space before: " << space << "Mb" << "\n";

        matrix a(dim, true, false);

        matrix b(dim, true, false);

        result std = run(mult_std, a, b);

        matrix c(dim, true, true);

        matrix d(dim, true, true);

        result strassen = run(mult_strassen, c, d);

        cout << "Dim " << "   Std  " << " Stranssen" << endl;

        cout << dim << " " << std.second << "ms " << strassen.second << "ms " << "\n";

        double spaceA = getValue() * .01;

        cout << "Space: " << spaceA << "Mb" << "\n";

        cout << " " << endl;
    }

}
1500 2406 4250 
1700 3463 4252 
1900 4819 4247 
2100 6487 30023 
Killed