使用Julia 1.0 findmax等效于numpy.argmax

使用Julia 1.0 findmax等效于numpy.argmax,julia,Julia,在Julia中,我想为每行中的最大值找到矩阵的列索引,结果是向量{Int}。下面是我目前的做法(Samples有7列和10000行): 这是可行的,但感觉相当笨拙和冗长。想知道是否有更好的方法。更新:为了完整起见,我将Matt B.的优秀解决方案添加到测试套件中(并且我还强制f4中的转置生成新的矩阵,而不是惰性视图) 以下是一些不同的方法(您的是基本情况f0): 使用BenchmarkTools我们可以检查每个工具的效率(我设置了x=rand(100200)): 因此,Matt的方法是相当明显的

在Julia中,我想为每行中的最大值找到矩阵的列索引,结果是
向量{Int}
。下面是我目前的做法(
Samples
有7列和10000行):


这是可行的,但感觉相当笨拙和冗长。想知道是否有更好的方法。

更新:为了完整起见,我将Matt B.的优秀解决方案添加到测试套件中(并且我还强制
f4
中的
转置
生成新的矩阵,而不是惰性视图)

以下是一些不同的方法(您的是基本情况
f0
):

使用
BenchmarkTools
我们可以检查每个工具的效率(我设置了
x=rand(100200)
):

因此,Matt的方法是相当明显的赢家,因为它似乎只是我的
f3
的一个语法更清晰的版本(这两个版本可能编译成非常相似的东西,但我认为检查这一点太过分了)

我希望
f4
可能有优势,尽管通过实例化
转置创建了临时,因为它可以对矩阵的列而不是行进行操作(Julia是一种列主语言,因此对列的操作总是更快,因为元素在内存中是同步的)。但这似乎不足以克服暂时性的缺点


注意,如果您需要完整的
CartesianIndex
,即每行中最大值的行和列索引,那么显然适当的解决方案就是
argmax(x,dims=2)

Mapslices函数也是解决此问题的一个很好的选项:

julia> Samples = rand(10000, 7);

julia> res = mapslices(row -> findmax(row)[2], Samples, dims=[2])[:,1];

julia> res[1:10]
10-element Array{Int64,1}:
 3
 1
 3
 5
 4
 4
 1
 4
 5
 3

虽然这比科林上面建议的要慢得多,但对某些人来说可能更具可读性。这基本上与您开始使用的代码完全相同,但使用
mapsicles
而不是列表理解。

更简单:Julia有一个
argmax
函数,Julia 1.1+有一个
每一个
迭代器。因此:

map(argmax, eachrow(x))

简单、易读、快速-它与科林的
f3
f4
在我的快速测试中的性能相匹配。

可以做
last.
而不是
getindex.
如果idk更清晰的话。@LyndonWhite我在v1.1的
lastindex(::CartesianIndex{2})
上得到一个方法错误。诚然,我对此感到惊讶
CartesianIndex
是一个iterable right,因此它应该有一个
lastindex
方法…@LyndonWhite,我刚刚尝试迭代一个
CartesianIndex
,结果是:
错误:CartesianIndex故意不支持迭代。使用I而不是I…,或者使用Tuple(I)…
,因此显然当前行为有很好的理由。我确实喜欢你的
最后一个
想法,但是如果你也必须在其中添加
元组,那么我认为它在语法上不会更清晰:-)迭代是有意的,而
lastindex
我不认为是
idx[end]
作为一件事是有意义的。你想提出一个问题吗?@LyndonWhite我确实从GitHub问题页面开始,但是被默认的横幅推迟了:
如果你有问题或者不确定你正在经历的行为是否是一个bug,请搜索或发布到我们的讨论网站
。我现在打开了一个问题对不起,我更新了我的答案大约7000次。我想我现在已经完成了:-如果我的答案是有帮助的和完整的,那么请考虑投票和标记回答的问题(点击我的响应旁边的刻度标记)。我可能会与<代码> MMAPARS/<代码>:<代码> MAPPLATE(ROO-> FIDEMAX(行)[2 ],样本,diMS=(2)] [:,1 ] < /代码>或:<代码> AgMax。(每行(样本))
Broadcast在这里效率不高,因为
eachrow
只是一个迭代器,所以Broadcast当前在使用它之前会将它收集到一个数组中-这是我希望在将来改进的东西,但现在它的性能会更差。另外,如果您仍停留在Julia 1.0上,请注意,我们在这里使用的
eachrow
定义非常简单!您可以临时将其添加到项目中:
julia> @btime f0($x);
  76.846 μs (13 allocations: 4.64 KiB)

julia> @btime f1($x);
  76.594 μs (11 allocations: 3.75 KiB)

julia> @btime f2($x);
  53.433 μs (103 allocations: 177.48 KiB)

julia> @btime f3($x);
  43.477 μs (3 allocations: 944 bytes)

julia> @btime f4($x);
  73.435 μs (6 allocations: 157.27 KiB)

julia> @btime f5($x);
  43.900 μs (4 allocations: 960 bytes)
julia> Samples = rand(10000, 7);

julia> res = mapslices(row -> findmax(row)[2], Samples, dims=[2])[:,1];

julia> res[1:10]
10-element Array{Int64,1}:
 3
 1
 3
 5
 4
 4
 1
 4
 5
 3
map(argmax, eachrow(x))