高效4x4矩阵乘法(C与汇编)

高效4x4矩阵乘法(C与汇编),c,optimization,assembly,sse,matrix-multiplication,C,Optimization,Assembly,Sse,Matrix Multiplication,我正在寻找一种更快、更复杂的方法,用C语言乘以两个4x4矩阵。我目前的研究重点是使用SIMD扩展的x86-64汇编。到目前为止,我已经创建了一个比简单的C实现快6倍的函数,这超出了我对性能改进的预期。不幸的是,只有当编译没有使用优化标志时(GCC4.7),这种情况才会发生。使用-O2,C变得更快,我的努力变得毫无意义 我知道,现代编译器使用复杂的优化技术来实现几乎完美的代码,通常比一个巧妙的手工组装快。但在少数性能关键的情况下,人类可能会试图与编译器争夺时钟周期。特别是,可以探索一些有现代ISA

我正在寻找一种更快、更复杂的方法,用C语言乘以两个4x4矩阵。我目前的研究重点是使用SIMD扩展的x86-64汇编。到目前为止,我已经创建了一个比简单的C实现快6倍的函数,这超出了我对性能改进的预期。不幸的是,只有当编译没有使用优化标志时(GCC4.7),这种情况才会发生。使用
-O2
,C变得更快,我的努力变得毫无意义

我知道,现代编译器使用复杂的优化技术来实现几乎完美的代码,通常比一个巧妙的手工组装快。但在少数性能关键的情况下,人类可能会试图与编译器争夺时钟周期。特别是,可以探索一些有现代ISA支持的数学(就像我的例子一样)

我的函数如下所示(AT&T语法,GNU汇编程序):

它通过处理128位SSE寄存器中的四个浮点数,每次迭代计算一整列结果矩阵。通过一点数学(操作重新排序和聚合)和用于并行乘法/加法4xfloat包的
mullps
/
addps
指令,完全矢量化是可能的。代码重用用于传递参数的寄存器(
%rdi
%rsi
%rdx
:GNU/Linux ABI),受益于(内部)循环展开,并将一个矩阵完全保存在XMM寄存器中以减少内存读取。A你可以看到,我已经研究了这个主题,并花了时间尽我所能实现它

征服我的代码的天真C计算如下所示:

void matrixMultiplyNormal(mat4_t *mat_a, mat4_t *mat_b, mat4_t *mat_r) {
    for (unsigned int i = 0; i < 16; i += 4)
        for (unsigned int j = 0; j < 4; ++j)
            mat_r->m[i + j] = (mat_b->m[i + 0] * mat_a->m[j +  0])
                            + (mat_b->m[i + 1] * mat_a->m[j +  4])
                            + (mat_b->m[i + 2] * mat_a->m[j +  8])
                            + (mat_b->m[i + 3] * mat_a->m[j + 12]);
}
void matrixMultiplyNormal(mat4\u t*mat\u a、mat4\u t*mat\u b、mat4\u t*mat\r){
for(无符号整数i=0;i<16;i+=4)
对于(无符号整数j=0;j<4;++j)
材料r->m[i+j]=(材料b->m[i+0]*材料a->m[j+0])
+(材料b->m[i+1]*材料a->m[j+4])
+(材料b->m[i+2]*材料a->m[j+8])
+(材料b->m[i+3]*材料a->m[j+12]);
}

我已经研究了上面的C代码的优化汇编输出,它在XMM寄存器中存储浮点数时,不涉及任何并行操作——只涉及标量计算、指针算术和条件跳转。编译器的代码似乎不那么刻意,但它仍然比我的矢量化版本稍微有效一些,预计速度要快4倍左右。我相信总体思路是正确的——程序员做类似的事情,结果是有回报的。但这里出了什么问题?是否有我不知道的寄存器分配或指令调度问题?您知道任何x86-64汇编工具或技巧来支持我与机器的战斗吗?

我想知道转置其中一个矩阵是否有益

考虑我们如何将以下两个矩阵相乘

A1 A2 A3 A4        W1 W2 W3 W4
B1 B2 B3 B4        X1 X2 X3 X4
C1 C2 C3 C4    *   Y1 Y2 Y3 Y4
D1 D2 D3 D4        Z1 Z2 Z3 Z4
这将导致

dot(A,?1) dot(A,?2) dot(A,?3) dot(A,?4)
dot(B,?1) dot(B,?2) dot(B,?3) dot(B,?4)
dot(C,?1) dot(C,?2) dot(C,?3) dot(C,?4)
dot(D,?1) dot(D,?2) dot(D,?3) dot(D,?4)
做一行和一列的点积是一件痛苦的事

如果我们在相乘之前转置第二个矩阵呢

A1 A2 A3 A4        W1 X1 Y1 Z1
B1 B2 B3 B4        W2 X2 Y2 Z2
C1 C2 C3 C4    *   W3 X3 Y3 Z3
D1 D2 D3 D4        W4 X4 Y4 Z4
现在我们做的是两行的点积,而不是一行和一列的点积。这有助于更好地使用SIMD指令


希望这有帮助。

4x4矩阵乘法是64次乘法和48次加法。使用SSE,这可以减少到16次乘法和12次加法(以及16次广播)。下面的代码将为您执行此操作。它只需要SSE(
#include
)。数组
A
B
C
需要16字节对齐。使用水平指令,例如
hadd
(SSE3)和
dpps
(SSE4.1)将是(尤其是
dpps
)。我不知道循环展开是否有帮助

void M4x4_SSE(float *A, float *B, float *C) {
    __m128 row1 = _mm_load_ps(&B[0]);
    __m128 row2 = _mm_load_ps(&B[4]);
    __m128 row3 = _mm_load_ps(&B[8]);
    __m128 row4 = _mm_load_ps(&B[12]);
    for(int i=0; i<4; i++) {
        __m128 brod1 = _mm_set1_ps(A[4*i + 0]);
        __m128 brod2 = _mm_set1_ps(A[4*i + 1]);
        __m128 brod3 = _mm_set1_ps(A[4*i + 2]);
        __m128 brod4 = _mm_set1_ps(A[4*i + 3]);
        __m128 row = _mm_add_ps(
                    _mm_add_ps(
                        _mm_mul_ps(brod1, row1),
                        _mm_mul_ps(brod2, row2)),
                    _mm_add_ps(
                        _mm_mul_ps(brod3, row3),
                        _mm_mul_ps(brod4, row4)));
        _mm_store_ps(&C[4*i], row);
    }
}
void M4x4_SSE(浮点*A、浮点*B、浮点*C){
__m128第1行=_mm_load_ps(&B[0]);
__m128第2行=_mm_load_ps(&B[4]);
__m128第3行=_mm_load_ps(&B[8]);
__m128第4行=_mm_load_ps(&B[12]);

对于(int i=0;i有一种方法可以加速代码并超越编译器。它不涉及任何复杂的管道分析或深层代码微观优化(这并不意味着它不能从中进一步受益)。优化使用三个简单的技巧:

  • 该函数现在是32字节对齐的(这大大提高了性能)

  • 主回路反向运行,从而将比较减少到零测试(基于EFLAGS)

  • 指令级地址算法被证明比“外部”指针计算更快(即使它需要两倍于«在3/4情况下»的添加量)。它将循环体缩短了四条指令,并减少了执行路径中的数据依赖性

  • 此外,代码使用相对跳转语法来抑制符号重新定义错误,当GCC尝试内联符号重新定义错误时(在放入
    asm
    语句并使用
    -O3
    编译后)会发生这种错误


    这是到目前为止我所看到的最快的x86-64实现。我将感谢、投票并接受任何为此提供更快组装件的答案!

    显然,您可以一次从四个矩阵中提取术语,并使用相同的算法同时将四个矩阵相乘。

    上面的Sandy Bridge扩展了说明支持8元矢量算法的离子集。考虑此实现。< /P>
    struct MATRIX {
        union {
            float  f[4][4];
            __m128 m[4];
            __m256 n[2];
        };
    };
    MATRIX myMultiply(MATRIX M1, MATRIX M2) {
        // Perform a 4x4 matrix multiply by a 4x4 matrix 
        // Be sure to run in 64 bit mode and set right flags
        // Properties, C/C++, Enable Enhanced Instruction, /arch:AVX 
        // Having MATRIX on a 32 byte bundry does help performance
        MATRIX mResult;
        __m256 a0, a1, b0, b1;
        __m256 c0, c1, c2, c3, c4, c5, c6, c7;
        __m256 t0, t1, u0, u1;
    
        t0 = M1.n[0];                                                   // t0 = a00, a01, a02, a03, a10, a11, a12, a13
        t1 = M1.n[1];                                                   // t1 = a20, a21, a22, a23, a30, a31, a32, a33
        u0 = M2.n[0];                                                   // u0 = b00, b01, b02, b03, b10, b11, b12, b13
        u1 = M2.n[1];                                                   // u1 = b20, b21, b22, b23, b30, b31, b32, b33
    
        a0 = _mm256_shuffle_ps(t0, t0, _MM_SHUFFLE(0, 0, 0, 0));        // a0 = a00, a00, a00, a00, a10, a10, a10, a10
        a1 = _mm256_shuffle_ps(t1, t1, _MM_SHUFFLE(0, 0, 0, 0));        // a1 = a20, a20, a20, a20, a30, a30, a30, a30
        b0 = _mm256_permute2f128_ps(u0, u0, 0x00);                      // b0 = b00, b01, b02, b03, b00, b01, b02, b03  
        c0 = _mm256_mul_ps(a0, b0);                                     // c0 = a00*b00  a00*b01  a00*b02  a00*b03  a10*b00  a10*b01  a10*b02  a10*b03
        c1 = _mm256_mul_ps(a1, b0);                                     // c1 = a20*b00  a20*b01  a20*b02  a20*b03  a30*b00  a30*b01  a30*b02  a30*b03
    
        a0 = _mm256_shuffle_ps(t0, t0, _MM_SHUFFLE(1, 1, 1, 1));        // a0 = a01, a01, a01, a01, a11, a11, a11, a11
        a1 = _mm256_shuffle_ps(t1, t1, _MM_SHUFFLE(1, 1, 1, 1));        // a1 = a21, a21, a21, a21, a31, a31, a31, a31
        b0 = _mm256_permute2f128_ps(u0, u0, 0x11);                      // b0 = b10, b11, b12, b13, b10, b11, b12, b13
        c2 = _mm256_mul_ps(a0, b0);                                     // c2 = a01*b10  a01*b11  a01*b12  a01*b13  a11*b10  a11*b11  a11*b12  a11*b13
        c3 = _mm256_mul_ps(a1, b0);                                     // c3 = a21*b10  a21*b11  a21*b12  a21*b13  a31*b10  a31*b11  a31*b12  a31*b13
    
        a0 = _mm256_shuffle_ps(t0, t0, _MM_SHUFFLE(2, 2, 2, 2));        // a0 = a02, a02, a02, a02, a12, a12, a12, a12
        a1 = _mm256_shuffle_ps(t1, t1, _MM_SHUFFLE(2, 2, 2, 2));        // a1 = a22, a22, a22, a22, a32, a32, a32, a32
        b1 = _mm256_permute2f128_ps(u1, u1, 0x00);                      // b0 = b20, b21, b22, b23, b20, b21, b22, b23
        c4 = _mm256_mul_ps(a0, b1);                                     // c4 = a02*b20  a02*b21  a02*b22  a02*b23  a12*b20  a12*b21  a12*b22  a12*b23
        c5 = _mm256_mul_ps(a1, b1);                                     // c5 = a22*b20  a22*b21  a22*b22  a22*b23  a32*b20  a32*b21  a32*b22  a32*b23
    
        a0 = _mm256_shuffle_ps(t0, t0, _MM_SHUFFLE(3, 3, 3, 3));        // a0 = a03, a03, a03, a03, a13, a13, a13, a13
        a1 = _mm256_shuffle_ps(t1, t1, _MM_SHUFFLE(3, 3, 3, 3));        // a1 = a23, a23, a23, a23, a33, a33, a33, a33
        b1 = _mm256_permute2f128_ps(u1, u1, 0x11);                      // b0 = b30, b31, b32, b33, b30, b31, b32, b33
        c6 = _mm256_mul_ps(a0, b1);                                     // c6 = a03*b30  a03*b31  a03*b32  a03*b33  a13*b30  a13*b31  a13*b32  a13*b33
        c7 = _mm256_mul_ps(a1, b1);                                     // c7 = a23*b30  a23*b31  a23*b32  a23*b33  a33*b30  a33*b31  a33*b32  a33*b33
    
        c0 = _mm256_add_ps(c0, c2);                                     // c0 = c0 + c2 (two terms, first two rows)
        c4 = _mm256_add_ps(c4, c6);                                     // c4 = c4 + c6 (the other two terms, first two rows)
        c1 = _mm256_add_ps(c1, c3);                                     // c1 = c1 + c3 (two terms, second two rows)
        c5 = _mm256_add_ps(c5, c7);                                     // c5 = c5 + c7 (the other two terms, second two rose)
    
                                                                        // Finally complete addition of all four terms and return the results
        mResult.n[0] = _mm256_add_ps(c0, c4);       // n0 = a00*b00+a01*b10+a02*b20+a03*b30  a00*b01+a01*b11+a02*b21+a03*b31  a00*b02+a01*b12+a02*b22+a03*b32  a00*b03+a01*b13+a02*b23+a03*b33
                                                    //      a10*b00+a11*b10+a12*b20+a13*b30  a10*b01+a11*b11+a12*b21+a13*b31  a10*b02+a11*b12+a12*b22+a13*b32  a10*b03+a11*b13+a12*b23+a13*b33
        mResult.n[1] = _mm256_add_ps(c1, c5);       // n1 = a20*b00+a21*b10+a22*b20+a23*b30  a20*b01+a21*b11+a22*b21+a23*b31  a20*b02+a21*b12+a22*b22+a23*b32  a20*b03+a21*b13+a22*b23+a23*b33
                                                    //      a30*b00+a31*b10+a32*b20+a33*b30  a30*b01+a31*b11+a32*b21+a33*b31  a30*b02+a31*b12+a32*b22+a33*b32  a30*b03+a31*b13+a32*b23+a33*b33
        return mResult;
    }
    

    最近的编译器可以比人类更好地进行微优化。专注于算法优化!这正是我所做的——我使用了另一种计算方法来适应SSE问题。它实际上是一种不同的算法。问题可能是,现在我还必须在指令级对其进行优化,因为,在专注于算法的同时嗯,我可能引入了数据依赖性问题、无效的内存访问模式或其他一些黑魔法。你可能会更好
        .text
        .align 32                           # 1. function entry alignment
        .globl matrixMultiplyASM            #    (for a faster call)
        .type matrixMultiplyASM, @function
    matrixMultiplyASM:
        movaps   (%rdi), %xmm0
        movaps 16(%rdi), %xmm1
        movaps 32(%rdi), %xmm2
        movaps 48(%rdi), %xmm3
        movq $48, %rcx                      # 2. loop reversal
    1:                                      #    (for simpler exit condition)
        movss (%rsi, %rcx), %xmm4           # 3. extended address operands
        shufps $0, %xmm4, %xmm4             #    (faster than pointer calculation)
        mulps %xmm0, %xmm4
        movaps %xmm4, %xmm5
        movss 4(%rsi, %rcx), %xmm4
        shufps $0, %xmm4, %xmm4
        mulps %xmm1, %xmm4
        addps %xmm4, %xmm5
        movss 8(%rsi, %rcx), %xmm4
        shufps $0, %xmm4, %xmm4
        mulps %xmm2, %xmm4
        addps %xmm4, %xmm5
        movss 12(%rsi, %rcx), %xmm4
        shufps $0, %xmm4, %xmm4
        mulps %xmm3, %xmm4
        addps %xmm4, %xmm5
        movaps %xmm5, (%rdx, %rcx)
        subq $16, %rcx                      # one 'sub' (vs 'add' & 'cmp')
        jge 1b                              # SF=OF, idiom: jump if positive
        ret
    
    struct MATRIX {
        union {
            float  f[4][4];
            __m128 m[4];
            __m256 n[2];
        };
    };
    MATRIX myMultiply(MATRIX M1, MATRIX M2) {
        // Perform a 4x4 matrix multiply by a 4x4 matrix 
        // Be sure to run in 64 bit mode and set right flags
        // Properties, C/C++, Enable Enhanced Instruction, /arch:AVX 
        // Having MATRIX on a 32 byte bundry does help performance
        MATRIX mResult;
        __m256 a0, a1, b0, b1;
        __m256 c0, c1, c2, c3, c4, c5, c6, c7;
        __m256 t0, t1, u0, u1;
    
        t0 = M1.n[0];                                                   // t0 = a00, a01, a02, a03, a10, a11, a12, a13
        t1 = M1.n[1];                                                   // t1 = a20, a21, a22, a23, a30, a31, a32, a33
        u0 = M2.n[0];                                                   // u0 = b00, b01, b02, b03, b10, b11, b12, b13
        u1 = M2.n[1];                                                   // u1 = b20, b21, b22, b23, b30, b31, b32, b33
    
        a0 = _mm256_shuffle_ps(t0, t0, _MM_SHUFFLE(0, 0, 0, 0));        // a0 = a00, a00, a00, a00, a10, a10, a10, a10
        a1 = _mm256_shuffle_ps(t1, t1, _MM_SHUFFLE(0, 0, 0, 0));        // a1 = a20, a20, a20, a20, a30, a30, a30, a30
        b0 = _mm256_permute2f128_ps(u0, u0, 0x00);                      // b0 = b00, b01, b02, b03, b00, b01, b02, b03  
        c0 = _mm256_mul_ps(a0, b0);                                     // c0 = a00*b00  a00*b01  a00*b02  a00*b03  a10*b00  a10*b01  a10*b02  a10*b03
        c1 = _mm256_mul_ps(a1, b0);                                     // c1 = a20*b00  a20*b01  a20*b02  a20*b03  a30*b00  a30*b01  a30*b02  a30*b03
    
        a0 = _mm256_shuffle_ps(t0, t0, _MM_SHUFFLE(1, 1, 1, 1));        // a0 = a01, a01, a01, a01, a11, a11, a11, a11
        a1 = _mm256_shuffle_ps(t1, t1, _MM_SHUFFLE(1, 1, 1, 1));        // a1 = a21, a21, a21, a21, a31, a31, a31, a31
        b0 = _mm256_permute2f128_ps(u0, u0, 0x11);                      // b0 = b10, b11, b12, b13, b10, b11, b12, b13
        c2 = _mm256_mul_ps(a0, b0);                                     // c2 = a01*b10  a01*b11  a01*b12  a01*b13  a11*b10  a11*b11  a11*b12  a11*b13
        c3 = _mm256_mul_ps(a1, b0);                                     // c3 = a21*b10  a21*b11  a21*b12  a21*b13  a31*b10  a31*b11  a31*b12  a31*b13
    
        a0 = _mm256_shuffle_ps(t0, t0, _MM_SHUFFLE(2, 2, 2, 2));        // a0 = a02, a02, a02, a02, a12, a12, a12, a12
        a1 = _mm256_shuffle_ps(t1, t1, _MM_SHUFFLE(2, 2, 2, 2));        // a1 = a22, a22, a22, a22, a32, a32, a32, a32
        b1 = _mm256_permute2f128_ps(u1, u1, 0x00);                      // b0 = b20, b21, b22, b23, b20, b21, b22, b23
        c4 = _mm256_mul_ps(a0, b1);                                     // c4 = a02*b20  a02*b21  a02*b22  a02*b23  a12*b20  a12*b21  a12*b22  a12*b23
        c5 = _mm256_mul_ps(a1, b1);                                     // c5 = a22*b20  a22*b21  a22*b22  a22*b23  a32*b20  a32*b21  a32*b22  a32*b23
    
        a0 = _mm256_shuffle_ps(t0, t0, _MM_SHUFFLE(3, 3, 3, 3));        // a0 = a03, a03, a03, a03, a13, a13, a13, a13
        a1 = _mm256_shuffle_ps(t1, t1, _MM_SHUFFLE(3, 3, 3, 3));        // a1 = a23, a23, a23, a23, a33, a33, a33, a33
        b1 = _mm256_permute2f128_ps(u1, u1, 0x11);                      // b0 = b30, b31, b32, b33, b30, b31, b32, b33
        c6 = _mm256_mul_ps(a0, b1);                                     // c6 = a03*b30  a03*b31  a03*b32  a03*b33  a13*b30  a13*b31  a13*b32  a13*b33
        c7 = _mm256_mul_ps(a1, b1);                                     // c7 = a23*b30  a23*b31  a23*b32  a23*b33  a33*b30  a33*b31  a33*b32  a33*b33
    
        c0 = _mm256_add_ps(c0, c2);                                     // c0 = c0 + c2 (two terms, first two rows)
        c4 = _mm256_add_ps(c4, c6);                                     // c4 = c4 + c6 (the other two terms, first two rows)
        c1 = _mm256_add_ps(c1, c3);                                     // c1 = c1 + c3 (two terms, second two rows)
        c5 = _mm256_add_ps(c5, c7);                                     // c5 = c5 + c7 (the other two terms, second two rose)
    
                                                                        // Finally complete addition of all four terms and return the results
        mResult.n[0] = _mm256_add_ps(c0, c4);       // n0 = a00*b00+a01*b10+a02*b20+a03*b30  a00*b01+a01*b11+a02*b21+a03*b31  a00*b02+a01*b12+a02*b22+a03*b32  a00*b03+a01*b13+a02*b23+a03*b33
                                                    //      a10*b00+a11*b10+a12*b20+a13*b30  a10*b01+a11*b11+a12*b21+a13*b31  a10*b02+a11*b12+a12*b22+a13*b32  a10*b03+a11*b13+a12*b23+a13*b33
        mResult.n[1] = _mm256_add_ps(c1, c5);       // n1 = a20*b00+a21*b10+a22*b20+a23*b30  a20*b01+a21*b11+a22*b21+a23*b31  a20*b02+a21*b12+a22*b22+a23*b32  a20*b03+a21*b13+a22*b23+a23*b33
                                                    //      a30*b00+a31*b10+a32*b20+a33*b30  a30*b01+a31*b11+a32*b21+a33*b31  a30*b02+a31*b12+a32*b22+a33*b32  a30*b03+a31*b13+a32*b23+a33*b33
        return mResult;
    }