Does MATLAB optimize diag(A*B)?


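Suppose I have two very large matrices, A (M-by-N) and B (N-by-M), and I need the diagonal of A*B. Computing the full A*B takes M*M*N multiplications, while computing only its diagonal takes M*N multiplications, because there is no need to compute the elements that end up off the diagonal.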
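Written out element by element, the i-th diagonal entry needs only row i of A and column i of B:

$$(AB)_{ii} = \sum_{k=1}^{N} A_{ik} B_{ki}, \qquad i = 1, \dots, M$$

so the M diagonal entries take $M \cdot N$ multiplications in total, versus $M \cdot M \cdot N$ for the full product. For M = N = 5000 that is $2.5 \times 10^{7}$ versus $1.25 \times 10^{11}$ multiplications.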


Does MATLAB recognize this and optimize diag(A*B) automatically, or is a for loop better in this case?

Yes, this is one of the rare cases where a for loop is better.

I ran the following script through the profiler:

M = 5000;
N = 5000;

A = rand(M, N); B = rand(N, M);
product = A*B;
diag1 = diag(product);

A = rand(M, N); B = rand(N, M);
diag2 = diag(A*B);

A = rand(M, N); B = rand(N, M);
diag3 = zeros(M,1);
for i=1:M
    diag3(i) = A(i,:) * B(:,i);
end
I reset A and B between tests in case MATLAB tries to speed things up by caching.
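If the profiler is not at hand, a rough comparison can also be made with tic/toc. This is a minimal sketch, not the script that produced the numbers below: it assumes smaller sizes (M = N = 1000) so it runs quickly, and it reuses one A and B so that the three outputs can be compared. Wall-clock times will vary by machine and MATLAB version.

```matlab
% Time the three diag(A*B) variants with tic/toc and check that they agree.
% Sizes are deliberately smaller than the 5000-by-5000 profiled case.
M = 1000;
N = 1000;
A = rand(M, N);
B = rand(N, M);

tic;
product = A*B;                    % full M-by-M product, then take the diagonal
diag1 = diag(product);
t1 = toc;

tic;
diag2 = diag(A*B);                % same work, without keeping the product
t2 = toc;

tic;
diag3 = zeros(M, 1);
for i = 1:M
    diag3(i) = A(i,:) * B(:,i);   % one row-times-column dot product per element
end
t3 = toc;

fprintf('diag(product): %.4f s\ndiag(A*B):     %.4f s\nfor-loop:      %.4f s\n', t1, t2, t3);

% The results may differ by floating-point rounding, so compare with a tolerance.
maxdiff = max(abs(diag1 - diag3));
```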

Results (edited for brevity):

 time   calls  line
 6.29       1     5  product = A*B;
<0.01       1     6  diag1 = diag(product);
 5.46       1     9  diag2 = diag(A*B);
            1    12  diag3 = zeros(M,1);
            1    13  for i=1:M
 0.52    5000    14      diag3(i) = A(i,:) * B(:,i);
<0.01    5000    15  end
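We can see that in this case the for-loop variant is an order of magnitude faster than the other two. The diag(A*B) variant is actually faster than the diag(product) variant, but the difference is marginal at best.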

I tried a few different values of M and N, and in my tests the for-loop variant was slower only when M=1.

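You can actually beat the for loop with the following magic function:

For large matrices (2000-by-2000 and larger) this is about twice as fast as the explicit for loop on my machine, and it is already faster for matrices larger than 500-by-500.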


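Note that all of these methods produce numerically different results, because they perform the summations and multiplications in a different order.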
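Because of this, results from different methods are best compared with a relative tolerance rather than isequal. A minimal sketch, using the diag(A*B) and sum(A.'.*B).' variants from this thread; the 300-by-300 size is an arbitrary assumption:

```matlab
% Two ways of computing diag(A*B) that accumulate the products differently.
M = 300;
N = 300;
A = rand(M, N);
B = rand(N, M);

d_full = diag(A*B);          % BLAS matrix product, then the diagonal
d_sum  = sum(A.' .* B).';    % elementwise products summed per column

% Bitwise equality is not guaranteed, so measure the relative difference.
reldiff = max(abs(d_full - d_sum) ./ abs(d_full));
fprintf('isequal: %d, max relative difference: %g\n', isequal(d_full, d_sum), reldiff);
```

The difference, when there is one, should be on the order of a few multiples of eps, far below anything that matters for most applications.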

You can compute only the diagonal elements without a loop: just use

sum(A.'.*B).'
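Why this works: A.' is N-by-M, the same shape as B, and element (k,i) of A.'.*B is A(i,k)*B(k,i). Summing down column i therefore gives exactly the i-th diagonal entry of A*B. A small worked example with assumed integer matrices, where every intermediate value is exact in double precision:

```matlab
A = [1 2; 3 4; 5 6];      % M-by-N with M = 3, N = 2
B = [1 0 2; 0 1 3];       % N-by-M

% A.' .* B is 2-by-3: [1 0 10; 0 4 18]. Column sums give [1 4 28].
d = sum(A.' .* B).';      % [1; 4; 28]

% Reference: forms all 9 entries of A*B, then keeps 3 of them.
d_ref = diag(A*B);        % [1; 4; 28]
```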


diag(A*B) can also be implemented as sum(A.*B',2). Let's benchmark it together with all the other implementations/solutions suggested for this problem.
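One caveat: in MATLAB ' is the conjugate transpose and .' is the plain transpose. For real B they are identical, but for complex B, sum(A.*B',2) returns the complex conjugate of diag(A*B). A minimal sketch with an assumed 2-by-2 complex example:

```matlab
A = [1 2; 3 4];
B = [1i 0; 0 2i];             % complex N-by-M matrix (illustrative)

d_ref = diag(A*B);            % [1i; 8i]
d_ctr = sum(A .* B', 2);      % B' conjugates: [-1i; -8i]
d_tr  = sum(A .* B.', 2);     % plain transpose: [1i; 8i]
```

Writing B.' costs nothing extra for real data and keeps the expression correct if complex inputs ever show up.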

The different approaches, implemented as functions for benchmarking, are listed below:

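  • Sum-multiplication method-1

    function out = sum_mult_method1(A,B)
    % Element (i,k) of A.*B' is A(i,k)*B(k,i); summing each row gives (A*B)(i,i).
    out = sum(A.*B',2);
    
  • Sum-multiplication method-2

    function out = sum_mult_method2(A,B)
    % The same products arranged column-wise: sum down each column of A.'.*B.
    out = sum(A.'.*B).';
    
  • For-loop method

    function out = for_loop_method(A,B)
    % One row-times-column dot product per diagonal element.
    M = size(A,1);
    out = zeros(M,1);
    for i=1:M
        out(i) = A(i,:) * B(:,i);
    end
    
  • Full/direct multiplication method

    function out = direct_mult_method(A,B)
    % Forms the full M-by-M product, then keeps only its diagonal.
    out = diag(A*B);
    
  • Bsxfun method

    function out = bsxfun_method(A,B)
    % Elementwise product through bsxfun (A and B.' are both M-by-N), then row sums.
    out = sum(bsxfun(@times,A,B.'),2);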
    
  • Benchmarking code

    num_runs = 1000;
    M_arr = [100 200 500 1000];
    N = 4;
    
    %// Warm up tic/toc.
    tic();
    elapsed = toc();
    tic();
    elapsed = toc();
    
    for k2 = 1:numel(M_arr)
        M = M_arr(k2);
    
        fprintf('\n')
        disp(strcat('*** Benchmarking sizes are M =',num2str(M),' and N = ',num2str(N)));
    
        A = randi(9,M,N);
        B = randi(9,N,M);
    
        disp('1. Sum-multiplication method-1');
        tic
        for k = 1:num_runs
            out1 = sum_mult_method1(A,B);
        end
        toc
        clear out1
    
        disp('2. Sum-multiplication method-2');
        tic
        for k = 1:num_runs
            out2 = sum_mult_method2(A,B);
        end
        toc
        clear out2
    
        disp('3. For-loop method');
        tic
        for k = 1:num_runs
            out3 = for_loop_method(A,B);
        end
        toc
        clear out3
    
        disp('4. Direct-multiplication method');
        tic
        for k = 1:num_runs
            out4 = direct_mult_method(A,B);
        end
        toc
        clear out4
    
        disp('5. Bsxfun method');
        tic
        for k = 1:num_runs
            out5 = bsxfun_method(A,B);
        end
        toc
        clear out5
    
    end
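Since all five functions should return the same M-by-1 vector, it is worth confirming that before reading anything into the timings. A minimal sketch, assuming the same kind of randi test data as the benchmark; the anonymous functions (with arrayfun standing in for the explicit loop) and the labels are illustrative stand-ins for the function files above, so that the script runs on its own:

```matlab
% Check that all five approaches return the same diagonal before timing them.
M = 200;
N = 4;
A = randi(9, M, N);   % small integers: every product and sum is exact
B = randi(9, N, M);

impls = cell(5, 1);
impls{1} = @(A,B) sum(A.*B', 2);                                  % sum-multiplication method-1
impls{2} = @(A,B) sum(A.'.*B).';                                  % sum-multiplication method-2
impls{3} = @(A,B) arrayfun(@(i) A(i,:)*B(:,i), (1:size(A,1)).');  % for-loop method, via arrayfun
impls{4} = @(A,B) diag(A*B);                                      % direct multiplication
impls{5} = @(A,B) sum(bsxfun(@times, A, B.'), 2);                 % bsxfun method
labels = {'sum-mult-1', 'sum-mult-2', 'for-loop', 'direct', 'bsxfun'};

ref = diag(A*B);
ok = false(numel(impls), 1);
for k = 1:numel(impls)
    ok(k) = isequal(impls{k}(A, B), ref);
    fprintf('%-10s matches diag(A*B): %d\n', labels{k}, ok(k));
end
```

With integer data isequal is safe here; with rand data a tolerance-based comparison would be needed instead.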
    
    Results

    *** Benchmarking sizes are M =100 and N =4
    1. Sum-multiplication method-1
    Elapsed time is 0.015242 seconds.
    2. Sum-multiplication method-2
    Elapsed time is 0.015180 seconds.
    3. For-loop method
    Elapsed time is 0.192021 seconds.
    4. Direct-multiplication method
    Elapsed time is 0.065543 seconds.
    5. Bsxfun method
    Elapsed time is 0.054149 seconds.
    
    *** Benchmarking sizes are M =200 and N =4
    1. Sum-multiplication method-1
    Elapsed time is 0.009138 seconds.
    2. Sum-multiplication method-2
    Elapsed time is 0.009428 seconds.
    3. For-loop method
    Elapsed time is 0.435735 seconds.
    4. Direct-multiplication method
    Elapsed time is 0.148908 seconds.
    5. Bsxfun method
    Elapsed time is 0.030946 seconds.
    
    *** Benchmarking sizes are M =500 and N =4
    1. Sum-multiplication method-1
    Elapsed time is 0.033287 seconds.
    2. Sum-multiplication method-2
    Elapsed time is 0.026405 seconds.
    3. For-loop method
    Elapsed time is 0.965260 seconds.
    4. Direct-multiplication method
    Elapsed time is 2.832855 seconds.
    5. Bsxfun method
    Elapsed time is 0.034923 seconds.
    
    *** Benchmarking sizes are M =1000 and N =4
    1. Sum-multiplication method-1
    Elapsed time is 0.026068 seconds.
    2. Sum-multiplication method-2
    Elapsed time is 0.032850 seconds.
    3. For-loop method
    Elapsed time is 1.775382 seconds.
    4. Direct-multiplication method
    Elapsed time is 13.764870 seconds.
    5. Bsxfun method
    Elapsed time is 0.044931 seconds.
    
    Interim conclusions

    It appears that the sum-multiply methods are the best approaches, although the bsxfun approach seems to be catching up with them as M increases from 100 to 1000.

    Next, larger benchmark sizes were tested using only the sum-multiply and bsxfun methods. The sizes are:

    M_arr = [1000 2000 5000 10000 20000 50000];
    
    The results are:

    *** Benchmarking sizes are M =1000 and N =4
    1. Sum-multiplication method-1
    Elapsed time is 0.030390 seconds.
    2. Sum-multiplication method-2
    Elapsed time is 0.032334 seconds.
    5. Bsxfun method
    Elapsed time is 0.047377 seconds.
    
    *** Benchmarking sizes are M =2000 and N =4
    1. Sum-multiplication method-1
    Elapsed time is 0.040111 seconds.
    2. Sum-multiplication method-2
    Elapsed time is 0.045132 seconds.
    5. Bsxfun method
    Elapsed time is 0.060762 seconds.
    
    *** Benchmarking sizes are M =5000 and N =4
    1. Sum-multiplication method-1
    Elapsed time is 0.099986 seconds.
    2. Sum-multiplication method-2
    Elapsed time is 0.103213 seconds.
    5. Bsxfun method
    Elapsed time is 0.117650 seconds.
    
    *** Benchmarking sizes are M =10000 and N =4
    1. Sum-multiplication method-1
    Elapsed time is 0.375604 seconds.
    2. Sum-multiplication method-2
    Elapsed time is 0.273726 seconds.
    5. Bsxfun method
    Elapsed time is 0.226791 seconds.
    
    *** Benchmarking sizes are M =20000 and N =4
    1. Sum-multiplication method-1
    Elapsed time is 1.906839 seconds.
    2. Sum-multiplication method-2
    Elapsed time is 1.849166 seconds.
    5. Bsxfun method
    Elapsed time is 1.344905 seconds.
    
    *** Benchmarking sizes are M =50000 and N =4
    1. Sum-multiplication method-1
    Elapsed time is 5.159177 seconds.
    2. Sum-multiplication method-2
    Elapsed time is 5.081211 seconds.
    5. Bsxfun method
    Elapsed time is 3.866018 seconds.
    
    Alternative benchmarking code (with timeit)

    Plot

    num_runs = 1000;
    M_arr = [1000 2000 5000 10000 20000 50000 100000 200000 500000 1000000];
    N = 4;
    
    timeall = zeros(5,numel(M_arr));
    for k2 = 1:numel(M_arr)
        M = M_arr(k2);
    
        A = rand(M,N);
        B = rand(N,M);
    
        f = @() sum_mult_method1(A,B);
        timeall(1,k2) = timeit(f);
        clear f
    
        f = @() sum_mult_method2(A,B);
        timeall(2,k2) = timeit(f);
        clear f
    
        f = @() bsxfun_method(A,B);
        timeall(5,k2) = timeit(f);
        clear f
    
    end
    
    figure,
    hold on
    plot(M_arr,timeall(1,:),'-ro')
    plot(M_arr,timeall(2,:),'-ko')
    plot(M_arr,timeall(5,:),'-.b')
    legend('sum-method1','sum-method2','bsxfun-method')
    xlabel('M ->')
    ylabel('Time(sec) ->')
    

    Final conclusions

    It seems the sum-multiplication methods do best up to a certain point, around M = 5000; beyond that, bsxfun seems to gain the upper hand.
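The crossover can be read off the numbers already posted. This snippet only re-uses the M >= 1000 timings from the larger-size results above (no new measurements) and compares bsxfun against the faster of the two sum-multiplication methods at each size:

```matlab
% Timings in seconds, copied from the larger-size results above (N = 4).
M_arr  = [1000 2000 5000 10000 20000 50000];
t_sum1 = [0.030390 0.040111 0.099986 0.375604 1.906839 5.159177];  % sum-multiplication method-1
t_sum2 = [0.032334 0.045132 0.103213 0.273726 1.849166 5.081211];  % sum-multiplication method-2
t_bsx  = [0.047377 0.060762 0.117650 0.226791 1.344905 3.866018];  % bsxfun method

t_best_sum = min([t_sum1; t_sum2]);         % faster sum-multiplication time per M
ratio = t_bsx ./ t_best_sum;                % > 1: sum-multiplication faster; < 1: bsxfun faster
first_bsx_win = M_arr(find(ratio < 1, 1));  % smallest tested M where bsxfun wins

fprintf('M = %6d  bsxfun/sum ratio = %.2f\n', [M_arr; ratio]);
```

first_bsx_win comes out as 10000, so on that machine the switch happens somewhere between M = 5000 and M = 10000.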

    Future work


    One could try different values of N and study how the implementations mentioned here perform.
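A starting point for that could look like the sketch below. It is only an illustration: M, the list of N values and the number of runs are assumptions, it uses tic/toc with one untimed warm-up call per method (timeit could be substituted on MATLAB versions that have it), and it covers just the two methods that were still competitive at large M.

```matlab
% Sketch: how do sum-multiplication and bsxfun scale with N at fixed M?
M = 5000;
N_arr = [4 16 64 256];
num_runs = 20;

t = zeros(2, numel(N_arr));          % row 1: sum(A.'.*B).', row 2: bsxfun
for k = 1:numel(N_arr)
    N = N_arr(k);
    A = rand(M, N);
    B = rand(N, M);

    d = sum(A.' .* B).';             % warm-up call, not timed
    tic;
    for r = 1:num_runs
        d = sum(A.' .* B).';
    end
    t(1, k) = toc / num_runs;

    d = sum(bsxfun(@times, A, B.'), 2);   % warm-up call, not timed
    tic;
    for r = 1:num_runs
        d = sum(bsxfun(@times, A, B.'), 2);
    end
    t(2, k) = toc / num_runs;
end

for k = 1:numel(N_arr)
    fprintf('N = %4d: sum %.5f s, bsxfun %.5f s\n', N_arr(k), t(1, k), t(2, k));
end
```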

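    Typically, how large are these numbers, N and M?
    The application where I ran into this problem: A is a weight vector in an artificial neural network, and B is the difference between the input vector and another weight matrix. In this particular case M is between 100 and 1000 and N = 4, so my original application may not actually qualify as "very large matrices" :P
    If all the benchmarking done in the answers shows that diag(A*B) is not optimized on the fly, contact MathWorks and make a feature request.
    +1 for this approach, and for the very appropriate note about numerical differences. If you have benchmarking code, perhaps you could also compare my solution and tell me the results?
    @LuisMendo: The bsxfun approach is about 20% faster than your solution for the 5000-by-5000 case. I guess this may be due to implicit parallelization in that approach.
    Thanks! Good to know. That's one more reason to like bsxfun. But * and sum could also benefit from parallelization... Or sum(A.*B',2).
    I have some benchmarking results on this, coming up! Hope that's OK.
    Great! I just asked horchler. By the way, I tend to avoid the ,2 in sum because it is slower, but now I'm not sure that feeling is right.
    Yes, I realized that shortly after submitting my answer and was about to edit it in. :)
    @LuisMendo Also added separate benchmark results for your implementation.
    timeit is what I'm looking into now, to check whether the benchmark results given here are reliable.
    These times look too small to be reliable, and you aren't doing any warm-up. If you're not familiar with reliable benchmarking but have a recent MATLAB, I suggest using timeit. Also, test bsxfun?
    @horchler Those numbers are the ones the OP suggested in the comments. Looking at timeit now.
    Those may be the matrix sizes, but that doesn't mean you can't do more runs and, perhaps more importantly, warm up the functions before timing them.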
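The warm-up advice can be illustrated by timing the first call separately from a batch of later ones. A minimal sketch; M, N and num_runs are assumptions, and it measures only the sum(A.'.*B).' method. timeit performs a similar warm-up-then-median procedure automatically.

```matlab
% Sketch: time a cold first call, then take the median of several warm runs.
M = 1000;
N = 4;
num_runs = 50;
A = rand(M, N);
B = rand(N, M);
f = @() sum(A.' .* B).';

tic; f(); t_first = toc;     % first call may include one-time overhead

times = zeros(num_runs, 1);
for r = 1:num_runs
    tic;
    f();
    times(r) = toc;
end
t_median = median(times);    % median is less sensitive to outliers than the mean
fprintf('first call: %.2e s, median of %d runs: %.2e s\n', t_first, num_runs, t_median);
```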