Algorithm 实现Strassen矩阵乘法算法时遇到的问题

Algorithm 实现Strassen矩阵乘法算法时遇到的问题,algorithm,matrix-multiplication,strassen,Algorithm,Matrix Multiplication,Strassen,在过去的几个小时里,我一直在尝试实现Strassen矩阵乘法算法,但始终得不到正确的乘积。我认为问题可能出在某个辅助函数(helpSub、createProd、helpProduct)上,也可能出在strass2函数本身的写法(语句顺序等)上。我完全被难住了,欢迎任何提示。我一直用两个4x4矩阵作为测试数据。我在网上看到了很多种p1-p7和c1-c4的写法,但似乎都不起作用。下面是我写的类: /* @author williamnewman public class strassen2 {

在过去的几个小时里,我一直在尝试实现Strassen矩阵乘法算法,但始终得不到正确的乘积。我认为问题可能出在某个辅助函数(helpSub、createProd、helpProduct)上,也可能出在strass2函数本身的写法(语句顺序等)上。我完全被难住了,欢迎任何提示。我一直用两个4x4矩阵作为测试数据。我在网上看到了很多种p1-p7和c1-c4的写法,但似乎都不起作用。下面是我写的类:

 /* @author williamnewman

public class strassen2 {

//Main Strassen multiplication function
//BASE CASE:
int [][] strass2(int[][] x, int[][]y){
    if(x.length == 1 && y.length == 1){
        System.out.println("Donezo");
        int [][] nu = new int[1][1];
        nu[0][0] = x[0][0] * y[0][0];
        return nu;

    }
    else{
   int[][] a,b,c,d,e,f,g,h;
   int dim = x.length/2;

//Dividing two matrices into 8 sub matrices
  System.out.println("A<B<C");
   a = helpSub(0,0,x);
   C(a);
   b = helpSub(0,dim,x);

   C(b);
   c = helpSub(dim,0,x);
   C(c);
   d = helpSub(dim,dim,x);
   C(d);
   e = helpSub(0,0,y);
   C(e);
   f = helpSub(0,dim,y);
   C(f);
   g = helpSub(dim,0,y);
   C(g);
   h = helpSub(dim,dim,y);
   C(h);

   int[][] p1,p2,p3,p4,p5,p6,p7;


//Creating p1-p7
   /
   p1 = strass2(a,subtract(f,h));
   p2 = strass2(h, add(a,b));
   p3 = strass2(e,add(c,d));
   p4 = strass2(d,subtract(g,e));
   p5 = strass2(add(a,d),add(e,h));
   p6 = strass2(subtract(b,d),add(g,h));
   p7 = strass2(subtract(a,c),add(e,f));
   int [][] prod;
   int [][] c1,c2,c3,c4;

//Creating c1-c4
   c1 = subtract(add(p6,p5),subtract(p4,p2));
   c2 = add(p1,p2);
   c3 = add(p3,p4);
   c4 = subtract(add(p1,p5),subtract(p3,p7));
   C(c1);
   System.out.println("C1::");
   C(c2);
   System.out.println("C2::");
   C(c3);
   System.out.println("C3::");
   C(c4);
   System.out.println("C4::");
//CREATES PRODUCT MATRIX
   prod = createProd(c1,c2,c3,c4);
   return prod;

    }




}

//Creates product matrix from c1-c4
int[][] createProd(int[][] c1, int[][] c2, int[][] c3, int[][] c4){
    int[][] product = new int[c1.length*2][c1.length*2];
    int mid = c1.length;
    int fin = c1.length * 2;
    helpProduct(0,0,mid,mid,product,c1);
    helpProduct(0,mid,mid,fin,product,c2);
    helpProduct(mid,0,fin,mid,product,c3);
    helpProduct(mid,mid,fin,fin,product,c4);

     System.out.println();
    System.out.println("PRODUCT::!:");
    C(product);
    return product;



}

    //Helper function to create larger matrix from submatrices
void helpProduct(int x, int y, int z1, int z2,int[][] product, int[][] a1){
    int indR = 0;
    int indC = 0;
    for(int i = x; i < z1; i++){
        indC = 0;
        for(int j = y; j < z2; j++){
            product[i][j] = a1[indR][indC];
            indC++;
        }
        indR++;
    }
}


    int[][] helpSub(int x, int y, int[][] mat){
    int[][] sub = new int[mat.length/2][mat.length/2];
    for(int i1 = 0, i2=x; i1 < (mat.length/2); i1++, i2++)
    for(int j1 = 0, j2=y; j1<(mat.length/2); j1++, j2++)
    {
            sub[i1][j1] = mat[i2][j2];
                           // System.out.println(sub[i1][j1]);
    }
    return sub;
}



//Normal Matrix Multiplication Function
int[][] multiply(int[][]a,int[][]b){
    MM nu = new MM(a,b);
    return nu.product;
}

    //Adds one matrix to the next
int[][] add(int[][]a, int[][]b){
    int [][] nu = new int[a.length][a[0].length];
    for(int i = 0; i < a.length; i++){
        for(int j = 0; j < a[i].length;j++){
            nu[i][j] = a[i][j] + b[i][j];
        }
    }
    return nu;
}

//Subtracts second matrix from the first
int[][] subtract(int[][] a, int[][] b){
    int [][] sub = new int[a.length][a.length];
    //System.out.println("made it");
    for(int i = 0; i < a.length; i++){
        for(int j = 0; j < a[i].length;j++){
            sub[i][j] = a[i][j] - b[i][j];
        }
    }
    return sub;
}
//Prints the matrix
 void C(int[][] product){
    for(int i = 0; i <product.length; i++){
        for(int j = 0; j < product[i].length; j++){
            System.out.print(product[j][i]  + " ");

        }
        System.out.println();
    }
}
}
以下是迄今为止的结果(预期结果为显示的第一个4x4矩阵,实际结果为显示的最后一个4x4矩阵):


我非常确定我的helpSub()函数可以工作,因为它们生成了正确的a-h。但是,我在strass2递归调用中使用的参数可能有问题。如果不够具体,我很抱歉。我只是有点厌倦了,很好奇是否有人看到任何明显的问题。

很抱歉之前描述得比较含糊,不过我似乎已经解决了这个问题。我改用了这个网站上的p1-p7和c1-c4公式([Strassen矩阵乘法公式][1])。

[1] :)

改用这些公式后,乘积矩阵基本正确,但仍有4个左右的数值不对。随后我把递归的基本情况改为:当x和y的长度小于等于2时直接相乘。这似乎修正了那4个错误的值。如果有人感兴趣,下面是我修改后的strassen2类代码:

/*
 * To change this license header, choose License Headers in Project Properties.
 * To change this template file, choose Tools | Templates
 * and open the template in the editor.
 */
package pkg2a;

/**
 *
 * @author williamnewman
 */
public class strassen2 {

    int [][] strass2(int[][] x, int[][]y){
        if(x.length <= 2 && y.length <= 2){ //!!!! MODIFICATION HERE !!
            return multiply(x,y);

        }
        else{
       int[][] a,b,c,d,e,f,g,h;
       int dim = x.length/2;

      System.out.println("A<B<C");
       a = helpSub(0,0,x);
       //C(a);
       b = helpSub(0,dim,x);

       //C(b);
       c = helpSub(dim,0,x);
       //C(c);
       d = helpSub(dim,dim,x);
       //C(d);
       e = helpSub(0,0,y);
       //C(e);
       f = helpSub(0,dim,y);
       //C(f);
       g = helpSub(dim,0,y);
       //C(g);
       h = helpSub(dim,dim,y);
       //C(h);

       int[][] p1,p2,p3,p4,p5,p6,p7;
      // createSub(x,y,a,b,c,d,e,f,g,h);
      int[] s1,s2,s3,s4,s5,s6,s7,s8,s9,s10; 

      //MODIFICATION HERE
       p1 = strass2(a,subtract(f,h));
       p2 = strass2(add(a,b),h);
       p3 = strass2(add(c,d),e);
       p4 = strass2(d,subtract(g,e));
       p5 = strass2(add(a,d),add(e,h));
       p6 = strass2(subtract(b,d),add(g,h));
       p7 = strass2(subtract(a,c),add(e,f));
       int [][] prod;
       int [][] c1,c2,c3,c4;
       c1 = subtract(add(p5,p4),subtract(p2,p6));
       c2 = add(p1,p2);
       c3 = add(p3,p4);
       c4 = subtract(add(p1,p5),add(p3,p7));
       //C(c1);
       //System.out.println("C1::");
       //C(c2);
       //System.out.println("C2::");
       //C(c3);
       //System.out.println("C3::");
       //C(c4);
       //System.out.println("C4::");
       prod = createProd(c1,c2,c3,c4);
       return prod;

        }




    }

    int[][] createProd(int[][] c1, int[][] c2, int[][] c3, int[][] c4){
        int[][] product = new int[c1.length*2][c1.length*2];
        int mid = c1.length;
        int fin = c1.length * 2;
        helpProduct(0,0,mid,mid,product,c1);
        helpProduct(0,mid,mid,fin,product,c2);
        helpProduct(mid,0,fin,mid,product,c3);
        helpProduct(mid,mid,fin,fin,product,c4);

         System.out.println();
        System.out.println("PRODUCT::!:");
        //C(product);
        return product;



    }

        //Helper function to create larger matrix from submatrices
    void helpProduct(int x, int y, int z1, int z2,int[][] product, int[][] a1){
        int indR = 0;
        int indC = 0;
        for(int i = x; i < z1; i++){
            indC = 0;
            for(int j = y; j < z2; j++){
                product[i][j] = a1[indR][indC];
                indC++;
            }
            indR++;
        }
    }

    /*
        void createSub(int[][]x, int[][]y,int[][] a,int[][] b,int[][] c, int[][] d, int[][] e, int[][] f, int [][] g, int[][] h){
       int div1R = x.length/2;
       int div1C = div1R;
       int div2R = div1R;
       int div2C = div1R;
       a = helpSub(0,0,div1R,div1C,x);
      // c(a);
       b = helpSub(0,div1C,div1R,x[0].length,x);
       //c(b);
       c = helpSub(div1R,0,x.length,div1C,x);
       //c(c);
       d = helpSub(div1R,div1C,x.length,x[0].length,x);
       //c(d);
       e = helpSub(0,0,div2R,div2C,y);
       //c(e);
       f = helpSub(0,div2C,div2R,y[0].length,y);
      // c(f);
       g = helpSub(div2R,0,y.length,div2C,y);
       //c(g);
       h = helpSub(div2R,div2C,y.length,y[0].length,y);
      // c(h);


    }
        */
        int[][] helpSub(int x, int y, int[][] mat){
        int[][] sub = new int[mat.length/2][mat.length/2];
        for(int i1 = 0, i2=x; i1 < (mat.length/2); i1++, i2++)
        for(int j1 = 0, j2=y; j1<(mat.length/2); j1++, j2++)
        {
                sub[i1][j1] = mat[i2][j2];
                               // System.out.println(sub[i1][j1]);
        }
        return sub;
    }


    int[][] multiply(int[][]a,int[][]b){
        MM nu = new MM(a,b);
        return nu.product;
    }

        //Adds one matrix to the next
    int[][] add(int[][]a, int[][]b){
        int [][] nu = new int[a.length][a[0].length];
        for(int i = 0; i < a.length; i++){
            for(int j = 0; j < a[i].length;j++){
                nu[i][j] = a[i][j] + b[i][j];
            }
        }
        return nu;
    }

    //Subtracts second matrix from the first
    int[][] subtract(int[][] a, int[][] b){
        int [][] sub = new int[a.length][a.length];
        //System.out.println("made it");
        int rows = a.length;
        int columns = a[0].length;
        for(int i = 0; i < rows; i++){
            for(int j = 0; j < columns;j++){
                sub[i][j] = a[i][j] - b[i][j];
            }
        }
        return sub;
    }

     void C(int[][] product){
        for(int i = 0; i <product.length; i++){
            for(int j = 0; j < product[i].length; j++){
                System.out.print(product[i][j]  + " ");

            }
            System.out.println();
        }
    }
}
/*
*要更改此许可证标题,请在“项目属性”中选择“许可证标题”。
*要更改此模板文件,请选择工具|模板
*然后在编辑器中打开模板。
*/
package pkg2a;
/**
*
*@author williamnewman
*/
public class strassen2{
int[][]strass2(int[][]x,int[][]y){

if(x.length ...
很抱歉说得含糊不清,不过我似乎已经解决了这个问题。我使用了这个网站上p1-p7和c1-c4的公式([Strassen矩阵乘法公式][1])。

[1] :)

在实现了这些公式之后,乘积矩阵几乎是正确的,但仍有4个左右的数值不对。然后我把基本情况改为:当x和y的长度小于等于2时直接相乘。这似乎修正了那4个错误的值。如果有人感兴趣,下面是我修改后的strassen2类代码:

/*
 * To change this license header, choose License Headers in Project Properties.
 * To change this template file, choose Tools | Templates
 * and open the template in the editor.
 */
package pkg2a;

/**
 *
 * @author williamnewman
 */
public class strassen2 {

    int [][] strass2(int[][] x, int[][]y){
        if(x.length <= 2 && y.length <= 2){ //!!!! MODIFICATION HERE !!
            return multiply(x,y);

        }
        else{
       int[][] a,b,c,d,e,f,g,h;
       int dim = x.length/2;

      System.out.println("A<B<C");
       a = helpSub(0,0,x);
       //C(a);
       b = helpSub(0,dim,x);

       //C(b);
       c = helpSub(dim,0,x);
       //C(c);
       d = helpSub(dim,dim,x);
       //C(d);
       e = helpSub(0,0,y);
       //C(e);
       f = helpSub(0,dim,y);
       //C(f);
       g = helpSub(dim,0,y);
       //C(g);
       h = helpSub(dim,dim,y);
       //C(h);

       int[][] p1,p2,p3,p4,p5,p6,p7;
      // createSub(x,y,a,b,c,d,e,f,g,h);
      int[] s1,s2,s3,s4,s5,s6,s7,s8,s9,s10; 

      //MODIFICATION HERE
       p1 = strass2(a,subtract(f,h));
       p2 = strass2(add(a,b),h);
       p3 = strass2(add(c,d),e);
       p4 = strass2(d,subtract(g,e));
       p5 = strass2(add(a,d),add(e,h));
       p6 = strass2(subtract(b,d),add(g,h));
       p7 = strass2(subtract(a,c),add(e,f));
       int [][] prod;
       int [][] c1,c2,c3,c4;
       c1 = subtract(add(p5,p4),subtract(p2,p6));
       c2 = add(p1,p2);
       c3 = add(p3,p4);
       c4 = subtract(add(p1,p5),add(p3,p7));
       //C(c1);
       //System.out.println("C1::");
       //C(c2);
       //System.out.println("C2::");
       //C(c3);
       //System.out.println("C3::");
       //C(c4);
       //System.out.println("C4::");
       prod = createProd(c1,c2,c3,c4);
       return prod;

        }




    }

    int[][] createProd(int[][] c1, int[][] c2, int[][] c3, int[][] c4){
        int[][] product = new int[c1.length*2][c1.length*2];
        int mid = c1.length;
        int fin = c1.length * 2;
        helpProduct(0,0,mid,mid,product,c1);
        helpProduct(0,mid,mid,fin,product,c2);
        helpProduct(mid,0,fin,mid,product,c3);
        helpProduct(mid,mid,fin,fin,product,c4);

         System.out.println();
        System.out.println("PRODUCT::!:");
        //C(product);
        return product;



    }

        //Helper function to create larger matrix from submatrices
    void helpProduct(int x, int y, int z1, int z2,int[][] product, int[][] a1){
        int indR = 0;
        int indC = 0;
        for(int i = x; i < z1; i++){
            indC = 0;
            for(int j = y; j < z2; j++){
                product[i][j] = a1[indR][indC];
                indC++;
            }
            indR++;
        }
    }

    /*
        void createSub(int[][]x, int[][]y,int[][] a,int[][] b,int[][] c, int[][] d, int[][] e, int[][] f, int [][] g, int[][] h){
       int div1R = x.length/2;
       int div1C = div1R;
       int div2R = div1R;
       int div2C = div1R;
       a = helpSub(0,0,div1R,div1C,x);
      // c(a);
       b = helpSub(0,div1C,div1R,x[0].length,x);
       //c(b);
       c = helpSub(div1R,0,x.length,div1C,x);
       //c(c);
       d = helpSub(div1R,div1C,x.length,x[0].length,x);
       //c(d);
       e = helpSub(0,0,div2R,div2C,y);
       //c(e);
       f = helpSub(0,div2C,div2R,y[0].length,y);
      // c(f);
       g = helpSub(div2R,0,y.length,div2C,y);
       //c(g);
       h = helpSub(div2R,div2C,y.length,y[0].length,y);
      // c(h);


    }
        */
        int[][] helpSub(int x, int y, int[][] mat){
        int[][] sub = new int[mat.length/2][mat.length/2];
        for(int i1 = 0, i2=x; i1 < (mat.length/2); i1++, i2++)
        for(int j1 = 0, j2=y; j1<(mat.length/2); j1++, j2++)
        {
                sub[i1][j1] = mat[i2][j2];
                               // System.out.println(sub[i1][j1]);
        }
        return sub;
    }


    int[][] multiply(int[][]a,int[][]b){
        MM nu = new MM(a,b);
        return nu.product;
    }

        //Adds one matrix to the next
    int[][] add(int[][]a, int[][]b){
        int [][] nu = new int[a.length][a[0].length];
        for(int i = 0; i < a.length; i++){
            for(int j = 0; j < a[i].length;j++){
                nu[i][j] = a[i][j] + b[i][j];
            }
        }
        return nu;
    }

    //Subtracts second matrix from the first
    int[][] subtract(int[][] a, int[][] b){
        int [][] sub = new int[a.length][a.length];
        //System.out.println("made it");
        int rows = a.length;
        int columns = a[0].length;
        for(int i = 0; i < rows; i++){
            for(int j = 0; j < columns;j++){
                sub[i][j] = a[i][j] - b[i][j];
            }
        }
        return sub;
    }

     void C(int[][] product){
        for(int i = 0; i <product.length; i++){
            for(int j = 0; j < product[i].length; j++){
                System.out.print(product[i][j]  + " ");

            }
            System.out.println();
        }
    }
}
/*
*要更改此许可证标题,请在“项目属性”中选择“许可证标题”。
*要更改此模板文件,请选择工具|模板
*然后在编辑器中打开模板。
*/
package pkg2a;
/**
*
*@author williamnewman
*/
public class strassen2{
int[][]strass2(int[][]x,int[][]y){
if(x.length ...
你可能需要说明具体是哪里不对。你肯定做过一些测试,例如两个4x4单位矩阵相乘却没有得到单位矩阵之类的。如果是这样,请分享出来:让你的示例完整且可运行,并给出输入、预期结果以及你实际得到的错误结果。如果别人只能在不运行代码的情况下通读你的全部代码,你还指望怎样得到帮助?这种方式难得多,而且这份代码也算不上特别好读。您可能需要说明具体是哪里不对。您肯定做过一些测试,例如两个4x4单位矩阵相乘却没有得到单位矩阵之类的。如果是这样,请分享出来:让您的示例完整且可运行,并给出输入、预期结果以及您实际得到的错误结果。如果别人只是通读您的全部代码而不去运行它,您还指望得到什么帮助?这种方式难得多,而且这份代码也算不上特别好读。
/*
 * To change this license header, choose License Headers in Project Properties.
 * To change this template file, choose Tools | Templates
 * and open the template in the editor.
 */
package pkg2a;

/**
 *
 * @author williamnewman
 */
public class strassen2 {

    int [][] strass2(int[][] x, int[][]y){
        if(x.length <= 2 && y.length <= 2){ //!!!! MODIFICATION HERE !!
            return multiply(x,y);

        }
        else{
       int[][] a,b,c,d,e,f,g,h;
       int dim = x.length/2;

      System.out.println("A<B<C");
       a = helpSub(0,0,x);
       //C(a);
       b = helpSub(0,dim,x);

       //C(b);
       c = helpSub(dim,0,x);
       //C(c);
       d = helpSub(dim,dim,x);
       //C(d);
       e = helpSub(0,0,y);
       //C(e);
       f = helpSub(0,dim,y);
       //C(f);
       g = helpSub(dim,0,y);
       //C(g);
       h = helpSub(dim,dim,y);
       //C(h);

       int[][] p1,p2,p3,p4,p5,p6,p7;
      // createSub(x,y,a,b,c,d,e,f,g,h);
      int[] s1,s2,s3,s4,s5,s6,s7,s8,s9,s10; 

      //MODIFICATION HERE
       p1 = strass2(a,subtract(f,h));
       p2 = strass2(add(a,b),h);
       p3 = strass2(add(c,d),e);
       p4 = strass2(d,subtract(g,e));
       p5 = strass2(add(a,d),add(e,h));
       p6 = strass2(subtract(b,d),add(g,h));
       p7 = strass2(subtract(a,c),add(e,f));
       int [][] prod;
       int [][] c1,c2,c3,c4;
       c1 = subtract(add(p5,p4),subtract(p2,p6));
       c2 = add(p1,p2);
       c3 = add(p3,p4);
       c4 = subtract(add(p1,p5),add(p3,p7));
       //C(c1);
       //System.out.println("C1::");
       //C(c2);
       //System.out.println("C2::");
       //C(c3);
       //System.out.println("C3::");
       //C(c4);
       //System.out.println("C4::");
       prod = createProd(c1,c2,c3,c4);
       return prod;

        }




    }

    int[][] createProd(int[][] c1, int[][] c2, int[][] c3, int[][] c4){
        int[][] product = new int[c1.length*2][c1.length*2];
        int mid = c1.length;
        int fin = c1.length * 2;
        helpProduct(0,0,mid,mid,product,c1);
        helpProduct(0,mid,mid,fin,product,c2);
        helpProduct(mid,0,fin,mid,product,c3);
        helpProduct(mid,mid,fin,fin,product,c4);

         System.out.println();
        System.out.println("PRODUCT::!:");
        //C(product);
        return product;



    }

        //Helper function to create larger matrix from submatrices
    void helpProduct(int x, int y, int z1, int z2,int[][] product, int[][] a1){
        int indR = 0;
        int indC = 0;
        for(int i = x; i < z1; i++){
            indC = 0;
            for(int j = y; j < z2; j++){
                product[i][j] = a1[indR][indC];
                indC++;
            }
            indR++;
        }
    }

    /*
        void createSub(int[][]x, int[][]y,int[][] a,int[][] b,int[][] c, int[][] d, int[][] e, int[][] f, int [][] g, int[][] h){
       int div1R = x.length/2;
       int div1C = div1R;
       int div2R = div1R;
       int div2C = div1R;
       a = helpSub(0,0,div1R,div1C,x);
      // c(a);
       b = helpSub(0,div1C,div1R,x[0].length,x);
       //c(b);
       c = helpSub(div1R,0,x.length,div1C,x);
       //c(c);
       d = helpSub(div1R,div1C,x.length,x[0].length,x);
       //c(d);
       e = helpSub(0,0,div2R,div2C,y);
       //c(e);
       f = helpSub(0,div2C,div2R,y[0].length,y);
      // c(f);
       g = helpSub(div2R,0,y.length,div2C,y);
       //c(g);
       h = helpSub(div2R,div2C,y.length,y[0].length,y);
      // c(h);


    }
        */
        int[][] helpSub(int x, int y, int[][] mat){
        int[][] sub = new int[mat.length/2][mat.length/2];
        for(int i1 = 0, i2=x; i1 < (mat.length/2); i1++, i2++)
        for(int j1 = 0, j2=y; j1<(mat.length/2); j1++, j2++)
        {
                sub[i1][j1] = mat[i2][j2];
                               // System.out.println(sub[i1][j1]);
        }
        return sub;
    }


    int[][] multiply(int[][]a,int[][]b){
        MM nu = new MM(a,b);
        return nu.product;
    }

        //Adds one matrix to the next
    int[][] add(int[][]a, int[][]b){
        int [][] nu = new int[a.length][a[0].length];
        for(int i = 0; i < a.length; i++){
            for(int j = 0; j < a[i].length;j++){
                nu[i][j] = a[i][j] + b[i][j];
            }
        }
        return nu;
    }

    //Subtracts second matrix from the first
    int[][] subtract(int[][] a, int[][] b){
        int [][] sub = new int[a.length][a.length];
        //System.out.println("made it");
        int rows = a.length;
        int columns = a[0].length;
        for(int i = 0; i < rows; i++){
            for(int j = 0; j < columns;j++){
                sub[i][j] = a[i][j] - b[i][j];
            }
        }
        return sub;
    }

     void C(int[][] product){
        for(int i = 0; i <product.length; i++){
            for(int j = 0; j < product[i].length; j++){
                System.out.print(product[i][j]  + " ");

            }
            System.out.println();
        }
    }
}