Matlab 批量矩阵乘法的优化

Matlab 批量矩阵乘法的优化,matlab,matrix-multiplication,Matlab,Matrix Multiplication,我大致得出以下结论: A = rand(10, 20, 30); B = rand(10, 30, 40); 我想得到一个大小为(10,20,40)的矩阵C,目前正在使用for循环: for i = 1:10 C(i, :, :) = squeeze(A(i, :, :)) * squeeze(B(i, :, :)); end 我试过做C=bsxfun(@mtimes,A,B),但这不起作用 优化这一点的最佳方法是什么?我不是在寻找可读性很好的代码,只是在寻找我能得到的最优化的代码

我大致得出以下结论:

A = rand(10, 20, 30);
B = rand(10, 30, 40);
我想得到一个大小为
(10,20,40)
的矩阵
C
,目前正在使用for循环:

for i = 1:10
    C(i, :, :) = squeeze(A(i, :, :)) * squeeze(B(i, :, :));
end
我试过做
C=bsxfun(@mtimes,A,B),但这不起作用

优化这一点的最佳方法是什么?我不是在寻找可读性很好的代码,只是在寻找我能得到的最优化的代码


谢谢。

您可以更改指定的维度以优化内存访问。毕竟,矩阵与一维数组一样长。对它们进行不同的切片可以(并且确实)访问相邻的值,而不是到处跳跃。您的代码是:

A = rand(20, 30, 10);
B = rand(30, 40, 10);
C = zeros(20, 40, 10);

for i = 1:10
    C(:, :, i) = A(:, :, i) * B(:, :, i);
end
注意,您甚至不需要
压缩
,因为Matlab会自动删除后面的单例维度,因此由于函数调用较少,您可以减少一些常量us

以下是我使用的代码:

close all; clear; clc;

N = 1000;
N = N+10; % Add a few initial runs to be trimmed off at the end

%% 1st dimension
% Preallocate C
A = rand(10, 20, 30); B = rand(10, 30, 40); C = zeros(10, 20, 40);

t1 = zeros(1,N); % Preallocate timing results vector
for j = 1:N % Do the multiplication N times
    tic
    for i = 1:10
        C(i, :, :) = squeeze(A(i, :, :)) * squeeze(B(i, :, :));
    end
    t1(j) = toc;
end

%% 2nd dimension
A = rand(20, 10, 30); B = rand(30, 10, 40); C = zeros(20, 10, 40);

t2 = zeros(1,N);
for j = 1:N
    tic
    for i = 1:10
        C(:, i, :) = squeeze(A(:, i, :)) * squeeze(B(:, i, :));
    end
    t2(j) = toc;
end

%% 3rd dimension
A = rand(20, 30, 10); B = rand(30, 40, 10); C = zeros(20, 40, 10);

t3 = zeros(1,N);
for j = 1:N
    tic
    for i = 1:10
        C(:, :, i) = A(:, :, i) * B(:, :, i);
    end
    t3(j) = toc;
end

%% Plot

% Trim initial runs and convert to microsecconds
t1 = t1(11:end)*1e6; t2 = t2(11:end)*1e6; t3 = t3(11:end)*1e6;

x = 1:N-10;
plot(x,t1,x,t2,x,t3);

grid on;
xlabel('trial number');
ylabel('running time / us');
legend('C(i,:,:)','C(:,i,:)','C(:,:,i)');
title(sprintf('t1 = %.0f, t2 = %.0f, t3 = %.0f us',median(t1),median(t2),median(t3)));

在我开始之前,重要的是要认识到矩阵乘法是一个非常昂贵的过程。它的渐近复杂性是O(n^3)(O(n^2.8)与strassens)。这意味着,虽然你可能不认为你在做很多计算,但实际上有数十亿的计算发生,你甚至都不知道。正因为如此,由于计算的数量太多,你真正能做的事情是有限的

如果您不希望用于循环,有两种方法可以在MATLAB中执行批量矩阵乘法

第一个是名为的函数。此函数在编译后使用稀疏矩阵对过程进行矢量化。但是,矩阵必须在前2维中。在代码中,此操作将是

A = rand(10, 20, 30);
B = rand(10, 30, 40);
A = permute(A,[2 3 1]); % Change the dimensions as mtimesx always multiplies the first 2 dimensions
B = permute(B,[2 3 1]);
C = mtimesx(A,B);
C = permute(C,[3 1 2]);
这将使您的问题描述的操作通常更快

或者,如果你有一个GPU,你可以用同样的方式

A = rand(10, 20, 30);
B = rand(10, 30, 40);
A = permute(A,[2 3 1]);
B = permute(B,[2 3 1]);
A = gpuArray(A);
B = gpuArray(B);
C = pagefun(@mtimes,A,B);
C = permute(C,[3 1 2]);
此方法将每个问题发送到GPU的一页上,如果使用单精度,此方法通常比mtimesx快得多

我修改了@MarcinKonowalczyk脚本来运行所有示例。如您所见,在本例中,mtimesx的性能最好,与其他方法相比有了相当大的改进

此外,此图使用1000次矩阵乘法,而不是10次,这里我们开始看到GPU相对于CPU的优势


您当前的方法看起来是正确的。慢吗?在我的电脑中需要
0.0003s
。请注意,你需要压缩每个矩阵,而不是乘法的结果。在我的机器上,它一点也不快,我认为
for
循环是罪魁祸首。请注意,在我的应用程序中,
A
B
要大得多。您是否正确地预先分配了
C
?是的,使用
0
@galah92当然速度将取决于大小,这一点很清楚。我显示的时间是您发布的整个循环示例的时间。
close all; clear;

N = 1000;
N = N+10; % Add a few initial runs to be trimmed off at the end

%% 1st dimension
% Preallocate C
num_problems = 10;
outer_left = 20;
inner = 30;
outer_right = 40;
A = rand(num_problems, outer_left, inner); B = rand(num_problems, inner, outer_right); C = zeros(num_problems, outer_left, outer_right);

t1 = zeros(1,N); % Preallocate timing results vector
for j = 1:N % Do the multiplication N times
    tic
    for i = 1:num_problems
        C(i, :, :) = squeeze(A(i, :, :)) * squeeze(B(i, :, :));
    end
    t1(j) = toc;
end

%% 2nd dimension
A = permute(A,[2 1 3]); B = permute(B,[2 1 3]); C = permute(C,[2 1 3]);

t2 = zeros(1,N);
for j = 1:N
    tic
    for i = 1:num_problems
        C(:, i, :) = squeeze(A(:, i, :)) * squeeze(B(:, i, :));
    end
    t2(j) = toc;
end

%% 3rd dimension
A = permute(A,[1 3 2]); B = permute(B,[1 3 2]); C = permute(C,[1 3 2]);

t3 = zeros(1,N);
for j = 1:N
    tic
    for i = 1:num_problems
        C(:, :, i) = A(:, :, i) * B(:, :, i);
    end
    t3(j) = toc;
end

t4 = zeros(1,N);
for ii = 1:N
    tic
    C = mtimesx(A,B);
    t4(ii) = toc;
end

A = gpuArray(A);
B = gpuArray(B);
t5 = zeros(1,N);
for ii = 1:N
    tic
    C = pagefun(@mtimes,A,B);
    t5(ii) = toc;
end

%% Plot

% Trim initial runs and convert to microsecconds
t1 = t1(11:end)*1e6; t2 = t2(11:end)*1e6; t3 = t3(11:end)*1e6;
t4 = t4(11:end)*1e6; t5 = t5(11:end)*1e6;

x = 1:N-10;
plot(x,t1,x,t2,x,t3,x,t4,x,t5);

grid on;
xlabel('trial number');
ylabel('running time / us');
legend('C(i,:,:)','C(:,i,:)','C(:,:,i)','mtimesx','pagefun');
title(sprintf('t1 = %.0f, t2 = %.0f, t3 = %.0f, t4 = %.0f, t5 = %.0f us',median(t1),median(t2),median(t3),median(t4),median(t5)));