matlab/octave-广义矩阵乘法

matlab/octave-广义矩阵乘法,matlab,matrix,octave,matrix-multiplication,Matlab,Matrix,Octave,Matrix Multiplication,我想做一个函数来推广矩阵乘法。基本上,它应该能够执行标准矩阵乘法,但它应该允许通过任何其他函数更改两个二进制运算符的乘积/和 目标是尽可能提高CPU和内存的效率。当然,它的效率总是低于A*B,但运营商的灵活性才是关键 以下是我在阅读后可以想到的几个命令: 方法1-3的问题是,在使用sum()折叠矩阵之前,它们将生成n个矩阵。4更好,因为它在bsxfun中执行sum(),但bsxfun仍然生成n个矩阵(除了它们大部分为空,只包含一个非零值向量作为和,其余的用0填充以满足维度要求) 我想要的是第四种

我想做一个函数来推广矩阵乘法。基本上,它应该能够执行标准矩阵乘法,但它应该允许通过任何其他函数更改两个二进制运算符的乘积/和

目标是尽可能提高CPU和内存的效率。当然,它的效率总是低于A*B,但运营商的灵活性才是关键

以下是我在阅读后可以想到的几个命令:

方法1-3的问题是,在使用sum()折叠矩阵之前,它们将生成n个矩阵。4更好,因为它在bsxfun中执行sum(),但bsxfun仍然生成n个矩阵(除了它们大部分为空,只包含一个非零值向量作为和,其余的用0填充以满足维度要求)

我想要的是第四种方法,但没有无用的0来节省内存


有什么想法吗?

在检查了一些处理函数(如bsxfun)之后,似乎不可能使用这些函数进行直接矩阵乘法(我所说的直接是指临时乘积不存储在内存中,而是尽快求和,然后处理其他和积),因为它们有固定大小的输出(或者与输入相同,或者使用bsxfun singleton展开——两个输入维度的笛卡尔积)。但是,可以稍微欺骗倍频程(这与检查输出维度的MatLab不起作用):

但是,不要使用它们,因为输出的值不可靠(倍频程可能会损坏甚至删除它们并返回0!)

所以现在我只是实现一个半矢量化版本,下面是我的函数:

function C = genmtimes(A, B, outop, inop)
% C = genmtimes(A, B, inop, outop)
% Generalized matrix multiplication between A and B. By default, standard sum-of-products matrix multiplication is operated, but you can change the two operators (inop being the element-wise product and outop the sum).
% Speed note: about 100-200x slower than A*A' and about 3x slower when A is sparse, so use this function only if you want to use a different set of inop/outop than the standard matrix multiplication.

if ~exist('inop', 'var')
    inop = @times;
end

if ~exist('outop', 'var')
    outop = @sum;
end

[n, m] = size(A);
[m2, o] = size(B);

if m2 ~= m
    error('nonconformant arguments (op1 is %ix%i, op2 is %ix%i)\n', n, m, m2, o);
end


C = [];
if issparse(A) || issparse(B)
    C = sparse(o,n);
else
    C = zeros(o,n);
end

A = A';
for i=1:n
    C(:,i) = outop(bsxfun(inop, A(:,i), B))';
end
C = C';

end
使用稀疏矩阵和普通矩阵进行测试:使用稀疏矩阵(慢3倍)比使用普通矩阵(慢约100倍)的性能差距小得多

我认为这比bsxfun实现慢,但至少不会溢出内存:

A = randi(10, 1000);
C = genmtimes(A, A');

如果有人能提供更好的,我仍在寻找更好的替代方案!

不必深入讨论细节,有些工具,如和,是快速通用矩阵和标量运算例程。您可以查看它们的代码,并根据需要调整它们。 它很可能比matlab的bsxfun更快。

为什么不利用它接受任意函数的能力呢

C = shiftdim(feval(f, (bsxfun(g, A.', permute(B,[1 3 2])))), 1);
这里

  • f
    外部函数(对应于矩阵乘法情况下的求和)。它应该接受任意大小的3D数组
    m
    x
    n
    x
    p
    ,并沿其列操作以返回
    1
    x
    m
    x
    p
    数组
  • g
    内部函数(对应于矩阵乘法情况下的乘积)。根据
    bsxfun
    ,它应接受两个大小相同的列向量或一个列向量和一个标量作为输入,并返回与输入大小相同的列向量作为输出
这在Matlab中工作。我还没有测试过倍频程


示例1:矩阵乘法:

>> f = @sum;   %// outer function: sum
>> g = @times; %// inner function: product
>> A = [1 2 3; 4 5 6];
>> B = [10 11; -12 -13; 14 15];
>> C = shiftdim(feval(f, (bsxfun(g, A.', permute(B,[1 3 2])))), 1)
C =
    28    30
    64    69
        (tic/toc times in seconds)
      (tested in R2014a on Windows 8)

    size      mtimes       my_mtimes 
    ____    __________     _________
     400     0.0026398       0.20282
     600      0.012039       0.68471
     800      0.014571        1.6922
    1000      0.026645        3.5107
    2000       0.20204         28.76
    4000        1.5578        221.51
检查:

>> A*B
ans =
    28    30
    64    69

<强>例2 < /强>:考虑上述两个矩阵的

>> f = @(x,y) sum(abs(x));     %// outer function: sum of absolute values
>> g = @(x,y) max(x./y, y./x); %// inner function: "symmetric" ratio
>> C = shiftdim(feval(f, (bsxfun(g, A.', permute(B,[1 3 2])))), 1)
C =
   14.8333   16.1538
    5.2500    5.6346
检查:手动计算
C(1,2)


这是您发布的解决方案的一个稍微完善的版本,有一些小的改进

我们检查行数是否多于列数,或者反过来检查行数是否多于列数,然后通过选择行与矩阵相乘或矩阵与列相乘(这样做的循环迭代次数最少)进行相应的乘法

注意:即使行少于列,这可能并不总是最好的策略(按行而不是按列);事实上,MATLAB数组存储在内存中使得按列切片更有效,因为元素是连续存储的。而访问行涉及按行遍历元素(这对缓存不友好——想想看)

除此之外,代码应该处理双精度/单精度、实精度/复杂精度、全精度/稀疏精度(以及不可能组合的错误)。它还考虑空矩阵和零维

function C = my_mtimes(A, B, outFcn, inFcn)
    % default arguments
    if nargin < 4, inFcn = @times; end
    if nargin < 3, outFcn = @sum; end

    % check valid input
    assert(ismatrix(A) && ismatrix(B), 'Inputs must be 2D matrices.');
    assert(isequal(size(A,2),size(B,1)),'Inner matrix dimensions must agree.');
    assert(isa(inFcn,'function_handle') && isa(outFcn,'function_handle'), ...
        'Expecting function handles.')

    % preallocate output matrix
    M = size(A,1);
    N = size(B,2);
    if issparse(A)
        args = {'like',A};
    elseif issparse(B)
        args = {'like',B};
    else
        args = {superiorfloat(A,B)};
    end
    C = zeros(M,N, args{:});

    % compute matrix multiplication
    % http://en.wikipedia.org/wiki/Matrix_multiplication#Inner_product
    if M < N
        % concatenation of products of row vectors with matrices
        % A*B = [a_1*B ; a_2*B ; ... ; a_m*B]
        for m=1:M
            %C(m,:) = A(m,:) * B;
            %C(m,:) = sum(bsxfun(@times, A(m,:)', B), 1);
            C(m,:) = outFcn(bsxfun(inFcn, A(m,:)', B), 1);
        end
    else
        % concatenation of products of matrices with column vectors
        % A*B = [A*b_1 , A*b_2 , ... , A*b_n]
        for n=1:N
            %C(:,n) = A * B(:,n);
            %C(:,n) = sum(bsxfun(@times, A, B(:,n)'), 2);
            C(:,n) = outFcn(bsxfun(inFcn, A, B(:,n)'), 2);
        end
    end
end

以下是测试代码:

sz = [10:10:100 200:200:1000 2000 4000];
t = zeros(numel(sz),2);
for i=1:numel(sz)
    n = sz(i); disp(n)
    A = rand(n,n);
    B = rand(n,n);

    tic
    C = A*B;
    t(i,1) = toc;
    tic
    D = my_mtimes(A,B);
    t(i,2) = toc;

    assert(norm(C-D) < 1e-6)
    clear A B C D
end

semilogy(sz, t*1000, '.-')
legend({'mtimes','my_mtimes'}, 'Interpreter','none', 'Location','NorthWest')
xlabel('Size N'), ylabel('Time [msec]'), title('Matrix Multiplication')
axis tight
另一种方法(使用三重循环):


下一步该怎么办? 如果你想挤出更多的性能,你必须移动到C/C++MEX文件以减少解释的MATLAB代码的开销。你仍然可以通过从MEX文件调用优化的BLAS/LAPACK例程来利用它们(参见示例).MATLAB附带了一个库,坦白地说,在英特尔处理器上进行线性代数计算时,你无法击败它


其他人已经提到了一些关于文件交换的提交,它们将通用矩阵例程实现为MEX文件(请参阅的答案)。如果您将它们与优化的BLAS库相链接,则这些功能尤其有效。

为什么不尝试使用稀疏矩阵来节省内存分配?您可能可以让它发挥作用。与bsxfun类似,但对于稀疏矩阵,我假设它在后台也能保持相当低的内存使用率。已经完成,的确,第四种方法应该能够从稀疏性中获益,但不幸的是,它不能与倍频程一起工作,因为它的bsxfun运算符不是稀疏友好的,所以所有内容都将存储在内存中。您的第三个和第四个示例不起作用。您的第一个示例不适用于MATLAB R2010b及更早版本。我的另一个问题是,矩阵有多大你正在处理的是你如此关心记忆的问题。@RodyOldenhuis:谢谢你的反馈,的确,第三和第四个方法只在八度上工作。还有一个原因需要找到替代方法,因为第四个方法的问题正是我想要解决的问题:输出维度不正确+
function C = my_mtimes(A, B, outFcn, inFcn)
    % default arguments
    if nargin < 4, inFcn = @times; end
    if nargin < 3, outFcn = @sum; end

    % check valid input
    assert(ismatrix(A) && ismatrix(B), 'Inputs must be 2D matrices.');
    assert(isequal(size(A,2),size(B,1)),'Inner matrix dimensions must agree.');
    assert(isa(inFcn,'function_handle') && isa(outFcn,'function_handle'), ...
        'Expecting function handles.')

    % preallocate output matrix
    M = size(A,1);
    N = size(B,2);
    if issparse(A)
        args = {'like',A};
    elseif issparse(B)
        args = {'like',B};
    else
        args = {superiorfloat(A,B)};
    end
    C = zeros(M,N, args{:});

    % compute matrix multiplication
    % http://en.wikipedia.org/wiki/Matrix_multiplication#Inner_product
    if M < N
        % concatenation of products of row vectors with matrices
        % A*B = [a_1*B ; a_2*B ; ... ; a_m*B]
        for m=1:M
            %C(m,:) = A(m,:) * B;
            %C(m,:) = sum(bsxfun(@times, A(m,:)', B), 1);
            C(m,:) = outFcn(bsxfun(inFcn, A(m,:)', B), 1);
        end
    else
        % concatenation of products of matrices with column vectors
        % A*B = [A*b_1 , A*b_2 , ... , A*b_n]
        for n=1:N
            %C(:,n) = A * B(:,n);
            %C(:,n) = sum(bsxfun(@times, A, B(:,n)'), 2);
            C(:,n) = outFcn(bsxfun(inFcn, A, B(:,n)'), 2);
        end
    end
end
        (tic/toc times in seconds)
      (tested in R2014a on Windows 8)

    size      mtimes       my_mtimes 
    ____    __________     _________
     400     0.0026398       0.20282
     600      0.012039       0.68471
     800      0.014571        1.6922
    1000      0.026645        3.5107
    2000       0.20204         28.76
    4000        1.5578        221.51
sz = [10:10:100 200:200:1000 2000 4000];
t = zeros(numel(sz),2);
for i=1:numel(sz)
    n = sz(i); disp(n)
    A = rand(n,n);
    B = rand(n,n);

    tic
    C = A*B;
    t(i,1) = toc;
    tic
    D = my_mtimes(A,B);
    t(i,2) = toc;

    assert(norm(C-D) < 1e-6)
    clear A B C D
end

semilogy(sz, t*1000, '.-')
legend({'mtimes','my_mtimes'}, 'Interpreter','none', 'Location','NorthWest')
xlabel('Size N'), ylabel('Time [msec]'), title('Matrix Multiplication')
axis tight
C = zeros(M,N, args{:});
for m=1:M
    for n=1:N
        %C(m,n) = A(m,:) * B(:,n);
        %C(m,n) = sum(bsxfun(@times, A(m,:)', B(:,n)));
        C(m,n) = outFcn(bsxfun(inFcn, A(m,:)', B(:,n)));
    end
end
C = zeros(M,N, args{:});
P = size(A,2); % = size(B,1);
for m=1:M
    for n=1:N
        for p=1:P
            %C(m,n) = C(m,n) + A(m,p)*B(p,n);
            %C(m,n) = plus(C(m,n), times(A(m,p),B(p,n)));
            C(m,n) = outFcn([C(m,n) inFcn(A(m,p),B(p,n))]);
        end
    end
end