Python 基于另一个numpy阵列的argmax切片一个numpy阵列

Python 基于另一个numpy阵列的argmax切片一个numpy阵列,python,numpy,Python,Numpy,我有两个数组,如下所示: import numpy as np np.random.seed(42) a = (np.random.uniform(size=[2, 5, 3]) * 100).astype(int) b = (np.random.uniform(size=[2, 5, 3]) * 100).astype(int) array([[[60, 17, 6], [94, 96, 80], [30, 9, 68], [44, 12

我有两个数组,如下所示:

import numpy as np

np.random.seed(42)
a = (np.random.uniform(size=[2, 5, 3]) * 100).astype(int)
b = (np.random.uniform(size=[2, 5, 3]) * 100).astype(int)
array([[[60, 17,  6],
        [94, 96, 80],
        [30,  9, 68],
        [44, 12, 49],
        [ 3, 90, 25]],

       [[66, 31, 52],
        [54, 18, 96],
        [77, 93, 89],
        [59, 92,  8],
        [19,  4, 32]]])
数组
a的输出

array([[[37, 95, 73],
        [59, 15, 15],
        [ 5, 86, 60],
        [70,  2, 96],
        [83, 21, 18]],

       [[18, 30, 52],
        [43, 29, 61],
        [13, 29, 36],
        [45, 78, 19],
        [51, 59,  4]]])
数组
b
的输出如下:

import numpy as np

np.random.seed(42)
a = (np.random.uniform(size=[2, 5, 3]) * 100).astype(int)
b = (np.random.uniform(size=[2, 5, 3]) * 100).astype(int)
array([[[60, 17,  6],
        [94, 96, 80],
        [30,  9, 68],
        [44, 12, 49],
        [ 3, 90, 25]],

       [[66, 31, 52],
        [54, 18, 96],
        [77, 93, 89],
        [59, 92,  8],
        [19,  4, 32]]])
现在,我可以使用以下代码获取数组
a
argmax

idx = np.argmax(a, axis=0)
print(idx)
输出:

array([[0, 0, 0],
       [0, 1, 1],
       [1, 0, 0],
       [0, 1, 0],
       [0, 1, 0]], dtype=int64)
是否可以使用array
a
的argmax输出对array
b
进行切片,以便获得以下输出:

array([[60, 17,  6],
       [94, 18, 96],
       [77,  9, 68],
       [44, 92, 49],
       [ 3, 4, 25]])

我尝试了不同的方法,但没有成功。请提供帮助。

使用numpy高级索引:

将numpy导入为np
np.随机种子(42)
a=(np.random.uniform(size=[2,5,3])*100).astype(int)
b=(np.random.uniform(size=[2,5,3])*100).astype(int)
idx=np.argmax(a,轴=0)
_,m,n=a.形状
b[idx,np.arange(m)[:无],np.arange(n)]
数组([[60,17,6],
[94, 18, 96],
[77,  9, 68],
[44, 92, 49],
[ 3,  4, 25]])

使用numpy高级索引:

将numpy导入为np
np.随机种子(42)
a=(np.random.uniform(size=[2,5,3])*100).astype(int)
b=(np.random.uniform(size=[2,5,3])*100).astype(int)
idx=np.argmax(a,轴=0)
_,m,n=a.形状
b[idx,np.arange(m)[:无],np.arange(n)]
数组([[60,17,6],
[94, 18, 96],
[77,  9, 68],
[44, 92, 49],
[ 3,  4, 25]])

非常确定有一些重复,但你可以做
np。沿着轴(b,idx[None],axis=0)[0]
。非常确定有一些重复,但你可以做
np。沿着轴(b,idx[None],axis=0)[0]