Arrays Julia`mapslices`当输出维度超过输入维度时
我试图将函数Arrays Julia`mapslices`当输出维度超过输入维度时,arrays,julia,Arrays,Julia,我试图将函数f(x::Array{Float64,1})->Array{Float64,2}应用于Julia中mxn数组中的每一行(如果相关的话,我使用v1.1) 现在,我希望我可以简单地应用mapslices,如下所示: # toy example of f f = (x -> randn(length(x), length(x))) A = randn(100, 50) # intent: apply f to every row in A and collect the resul
f(x::Array{Float64,1})->Array{Float64,2}
应用于Julia中mxn
数组中的每一行(如果相关的话,我使用v1.1)
现在,我希望我可以简单地应用mapslices
,如下所示:
# toy example of f
f = (x -> randn(length(x), length(x)))
A = randn(100, 50)
# intent: apply f to every row in A and collect the result into a 100 x 50 x 50 matrix.
result = mapslices(f, A, dims=2)
不幸的是,mapslices
失败,并显示以下消息:
错误:维度不匹配(“试图将2个元素分配给1个目标”)
对于这种情况是否有类似的mapslices
?我知道我可以通过for
循环来实现这一点,但我希望有更简单的方法
更新:显然,一种方法是将A
嵌入到三维阵列中:
result = mapslices(f, reshape(A, (size(A)..., 1)), dims=[2, 3])
mapslices
不是很灵活。以下是你所要求的一种变体,不用它即可完成:
julia> f(x::AbstractVector) = x .* x'; # vector -> matrix
julia> A = randn(5, 7);
julia> f(A[1,:]) |> size
(7, 7)
julia> reduce(hcat, map(f, eachrow(A))) |> size
(7, 35)
julia> B = reshape(reduce(hcat, map(f, eachrow(A))), (7,7,5));
julia> B[:,:,3] ≈ f(A[3,:])
true
这里reduce(hcat,…)
将f
中的矩阵沿第二个方向进行组合;我们需要对进行重塑
以再次分离A
的第一个索引,该索引已成为B
的最后一个索引——这5行中的哪一行构成了此切片
还有许多一揽子解决方案:
julia> using JuliennedArrays
julia> Slices(A,2) |> size # like eachrow
(5,)
julia> C = Align(map(f, Slices(A,2)), 1,2); # size (7, 7, 5)
julia> C ≈ B
true
julia> D = Align(map(f, Slices(A,2)), 2,3); # size (5, 7, 7)
julia> D[3,:,:] ≈ f(A[3,:])
true
我只尝试了Julia v1.6,但我认为您可以用
每行应用f
(出现在Julia v1.1中):
我认为扩展数组维度为结果腾出空间是最好的方法。它也很好地概括了。很有趣!我喜欢reduce(…)
方法。谢谢
julia> using TensorCast
julia> @cast C2[i,j,k] := f(A[k,:])[i,j]; # := makes a new array
julia> C2 ≈ B
true
julia> @cast D2[k,i,j] := f(A[k,:])[i,j];
julia> D2 ≈ D
true
julia> f(x) = randn(length(x), length(x))
f (generic function with 1 method)
julia> A = rand(3,2)
3×2 Matrix{Float64}:
0.47239 0.179252
0.542389 0.25828
0.0513623 0.630193
julia> f.(eachrow(A))
3-element Vector{Matrix{Float64}}:
[-2.183245554875081 -0.16762649791435957; 0.9124553173227186 0.9148972946316921]
[-0.7322553194397725 -0.5844536492551982; 0.37738478201981623 -0.7056092457600269]
[1.0091890849396576 1.6451194487283958; 0.674221636656597 1.0509408618443663]