Matlab 批量矩阵乘法的优化
我大致得出以下结论:Matlab 批量矩阵乘法的优化,matlab,matrix-multiplication,Matlab,Matrix Multiplication,我大致得出以下结论: A = rand(10, 20, 30); B = rand(10, 30, 40); 我想得到一个大小为(10,20,40)的矩阵C,目前正在使用for循环: for i = 1:10 C(i, :, :) = squeeze(A(i, :, :)) * squeeze(B(i, :, :)); end 我试过做C=bsxfun(@mtimes,A,B),但这不起作用 优化这一点的最佳方法是什么?我不是在寻找可读性很好的代码,只是在寻找我能得到的最优化的代码
A = rand(10, 20, 30);
B = rand(10, 30, 40);
我想得到一个大小为(10,20,40)
的矩阵C
,目前正在使用for循环:
for i = 1:10
C(i, :, :) = squeeze(A(i, :, :)) * squeeze(B(i, :, :));
end
我试过做C=bsxfun(@mtimes,A,B)代码>,但这不起作用
优化这一点的最佳方法是什么?我不是在寻找可读性很好的代码,只是在寻找我能得到的最优化的代码
谢谢。您可以更改指定的维度以优化内存访问。毕竟,矩阵与一维数组一样长。对它们进行不同的切片可以(并且确实)访问相邻的值,而不是到处跳跃。您的代码是:
A = rand(20, 30, 10);
B = rand(30, 40, 10);
C = zeros(20, 40, 10);
for i = 1:10
C(:, :, i) = A(:, :, i) * B(:, :, i);
end
注意,您甚至不需要压缩
,因为Matlab会自动删除后面的单例维度,因此由于函数调用较少,您可以减少一些常量us
以下是我使用的代码:
close all; clear; clc;
N = 1000;
N = N+10; % Add a few initial runs to be trimmed off at the end
%% 1st dimension
% Preallocate C
A = rand(10, 20, 30); B = rand(10, 30, 40); C = zeros(10, 20, 40);
t1 = zeros(1,N); % Preallocate timing results vector
for j = 1:N % Do the multiplication N times
tic
for i = 1:10
C(i, :, :) = squeeze(A(i, :, :)) * squeeze(B(i, :, :));
end
t1(j) = toc;
end
%% 2nd dimension
A = rand(20, 10, 30); B = rand(30, 10, 40); C = zeros(20, 10, 40);
t2 = zeros(1,N);
for j = 1:N
tic
for i = 1:10
C(:, i, :) = squeeze(A(:, i, :)) * squeeze(B(:, i, :));
end
t2(j) = toc;
end
%% 3rd dimension
A = rand(20, 30, 10); B = rand(30, 40, 10); C = zeros(20, 40, 10);
t3 = zeros(1,N);
for j = 1:N
tic
for i = 1:10
C(:, :, i) = A(:, :, i) * B(:, :, i);
end
t3(j) = toc;
end
%% Plot
% Trim initial runs and convert to microsecconds
t1 = t1(11:end)*1e6; t2 = t2(11:end)*1e6; t3 = t3(11:end)*1e6;
x = 1:N-10;
plot(x,t1,x,t2,x,t3);
grid on;
xlabel('trial number');
ylabel('running time / us');
legend('C(i,:,:)','C(:,i,:)','C(:,:,i)');
title(sprintf('t1 = %.0f, t2 = %.0f, t3 = %.0f us',median(t1),median(t2),median(t3)));
在我开始之前,重要的是要认识到矩阵乘法是一个非常昂贵的过程。它的渐近复杂性是O(n^3)(O(n^2.8)与strassens)。这意味着,虽然你可能不认为你在做很多计算,但实际上有数十亿的计算发生,你甚至都不知道。正因为如此,由于计算的数量太多,你真正能做的事情是有限的
如果您不希望用于循环,有两种方法可以在MATLAB中执行批量矩阵乘法
第一个是名为的函数。此函数在编译后使用稀疏矩阵对过程进行矢量化。但是,矩阵必须在前2维中。在代码中,此操作将是
A = rand(10, 20, 30);
B = rand(10, 30, 40);
A = permute(A,[2 3 1]); % Change the dimensions as mtimesx always multiplies the first 2 dimensions
B = permute(B,[2 3 1]);
C = mtimesx(A,B);
C = permute(C,[3 1 2]);
这将使您的问题描述的操作通常更快
或者,如果你有一个GPU,你可以用同样的方式
A = rand(10, 20, 30);
B = rand(10, 30, 40);
A = permute(A,[2 3 1]);
B = permute(B,[2 3 1]);
A = gpuArray(A);
B = gpuArray(B);
C = pagefun(@mtimes,A,B);
C = permute(C,[3 1 2]);
此方法将每个问题发送到GPU的一页上,如果使用单精度,此方法通常比mtimesx快得多
我修改了@MarcinKonowalczyk脚本来运行所有示例。如您所见,在本例中,mtimesx的性能最好,与其他方法相比有了相当大的改进
此外,此图使用1000次矩阵乘法,而不是10次,这里我们开始看到GPU相对于CPU的优势
您当前的方法看起来是正确的。慢吗?在我的电脑中需要0.0003s
。请注意,你需要压缩每个矩阵,而不是乘法的结果。在我的机器上,它一点也不快,我认为for
循环是罪魁祸首。请注意,在我的应用程序中,A
和B
要大得多。您是否正确地预先分配了C
?是的,使用0
@galah92当然速度将取决于大小,这一点很清楚。我显示的时间是您发布的整个循环示例的时间。
close all; clear;
N = 1000;
N = N+10; % Add a few initial runs to be trimmed off at the end
%% 1st dimension
% Preallocate C
num_problems = 10;
outer_left = 20;
inner = 30;
outer_right = 40;
A = rand(num_problems, outer_left, inner); B = rand(num_problems, inner, outer_right); C = zeros(num_problems, outer_left, outer_right);
t1 = zeros(1,N); % Preallocate timing results vector
for j = 1:N % Do the multiplication N times
tic
for i = 1:num_problems
C(i, :, :) = squeeze(A(i, :, :)) * squeeze(B(i, :, :));
end
t1(j) = toc;
end
%% 2nd dimension
A = permute(A,[2 1 3]); B = permute(B,[2 1 3]); C = permute(C,[2 1 3]);
t2 = zeros(1,N);
for j = 1:N
tic
for i = 1:num_problems
C(:, i, :) = squeeze(A(:, i, :)) * squeeze(B(:, i, :));
end
t2(j) = toc;
end
%% 3rd dimension
A = permute(A,[1 3 2]); B = permute(B,[1 3 2]); C = permute(C,[1 3 2]);
t3 = zeros(1,N);
for j = 1:N
tic
for i = 1:num_problems
C(:, :, i) = A(:, :, i) * B(:, :, i);
end
t3(j) = toc;
end
t4 = zeros(1,N);
for ii = 1:N
tic
C = mtimesx(A,B);
t4(ii) = toc;
end
A = gpuArray(A);
B = gpuArray(B);
t5 = zeros(1,N);
for ii = 1:N
tic
C = pagefun(@mtimes,A,B);
t5(ii) = toc;
end
%% Plot
% Trim initial runs and convert to microsecconds
t1 = t1(11:end)*1e6; t2 = t2(11:end)*1e6; t3 = t3(11:end)*1e6;
t4 = t4(11:end)*1e6; t5 = t5(11:end)*1e6;
x = 1:N-10;
plot(x,t1,x,t2,x,t3,x,t4,x,t5);
grid on;
xlabel('trial number');
ylabel('running time / us');
legend('C(i,:,:)','C(:,i,:)','C(:,:,i)','mtimesx','pagefun');
title(sprintf('t1 = %.0f, t2 = %.0f, t3 = %.0f, t4 = %.0f, t5 = %.0f us',median(t1),median(t2),median(t3),median(t4),median(t5)));