Performance 变尺寸矩阵乘法的优化

Performance 变尺寸矩阵乘法的优化,performance,julia,matrix-multiplication,Performance,Julia,Matrix Multiplication,假设我有以下数据生成过程 using Random using StatsBase m_1 = [1.0 2.0] m_2 = [1.0 2.0; 3.0 4.0] DD = [] y = zeros(2,200) for i in 1:100 rand!(m_1) rand!(m_2) push!(DD, m_1) push!(DD, m_2) end idxs = sample(1:200,10) for i in id

假设我有以下数据生成过程

using Random
using StatsBase

m_1    = [1.0 2.0]
m_2    = [1.0 2.0; 3.0 4.0]
DD     = []
y      = zeros(2,200)

for i in 1:100
    rand!(m_1)
    rand!(m_2)
    push!(DD, m_1)
    push!(DD, m_2)
end

idxs   = sample(1:200,10)
for i in idxs
    DD[i] = DD[1]
end
假设给定数据,我有以下函数


function test(y, DD, n)
    v_1   = [1 2]
    v_2   = [3 4]
    for j in 1:n
        for i in 1:size(DD,1)
            if size(DD[i],1) == 1
                y[1:size(DD[i],1),i] .= (v_1 * DD[i]')[1]
            else
                y[1:size(DD[i],1),i] = (v_2 * DD[i]')'
            end
        end
    end
end
我正在努力优化
测试的速度。特别是,内存分配随着I的增加而增加
n
。然而,我并没有真正分配任何新的东西

数据生成过程捕获了这样一个事实:我事先不确定
DD[I]
的大小。也就是说,我第一次调用
test
DD[1]
可以是一个2x2矩阵。第二次调用
test
DD[1]
可能是一个1x2矩阵。我认为这可能是内存分配问题的一部分:朱莉娅事先不知道大小


我完全卡住了。我试过
@inbounds
,但没用。有什么方法可以改进这一点吗?

检查性能的一个重要方面是Julia能够理解类型。您可以通过运行
@code\u warntypetest(y,DD,1)
来检查这一点,输出将明确
DD
属于
Any[]
类型(因为您是这样声明的)。使用
Any
可能会导致相当大的性能损失,因此声明
DD=Matrix{Float64}[]
将测试时间缩短到三分之一

我不确定这个示例与您想要编写的实际代码有多接近,但是在这个特殊情况下,
size(DD[I],1)==1
分支可以被调用
linearagebra.dot
替换:

y[1:size(DD[i],1),i] .= dot(v_1, DD[i])
这使我的时间又减少了50%。最后,通过使用
mul,您可以再挤出一点点执行其他乘法到位:

mul!(view(y, 1:size(DD[i],1),i:i), DD[i], v_2')
完整示例:

using Random
using LinearAlgebra

DD = [rand(i,2) for _ in 1:100 for i in 1:2]
y  = zeros(2,200)
shuffle!(DD)

function test(y, DD, n)
    v_1   = [1 2]
    v_2   = [3 4]'
    for j in 1:n
        for i in 1:size(DD,1)
            if size(DD[i],1) == 1
                y[1:size(DD[i],1),i] .= dot(v_1, DD[i])
            else
                mul!(view(y, 1:size(DD[i],1),i:i), DD[i], v_2)
            end
        end
    end
end

我真的不清楚你在这个代码中的意图是什么。澄清可能会使回答更容易。此外,调用
rand可能不会做你想做的事。因为它们在适当的位置修改了参数,所以您不会创建新的数组,而是在适当的位置不断更改旧数组,即使它们是
DD
的一部分。这意味着
DD
将由相同的两个数组组成
m1
m2
重复和混合。做
display(DD)
你就会明白我的意思了。