Python 用(n-1)d数组索引n维数组

Python 用(n-1)d数组索引n维数组,python,numpy,Python,Numpy,使用(n-1)维数组沿给定维度访问n维数组(如虚拟示例中所示)最优雅的方式是什么 a = np.random.random_sample((3,4,4)) b = np.random.random_sample((3,4,4)) idx = np.argmax(a, axis=0) 我现在如何使用idx a访问以获得a中的最大值,就好像我使用了a.max(axis=0)?或者如何检索b中idx指定的值 我曾考虑过使用np.meshgrid,但我认为这是一种过度使用。请注意,尺寸轴可以是任何有用

使用(n-1)维数组沿给定维度访问n维数组(如虚拟示例中所示)最优雅的方式是什么

a = np.random.random_sample((3,4,4))
b = np.random.random_sample((3,4,4))
idx = np.argmax(a, axis=0)
我现在如何使用
idx a
访问以获得
a
中的最大值,就好像我使用了
a.max(axis=0)
?或者如何检索
b
idx
指定的值

我曾考虑过使用
np.meshgrid
,但我认为这是一种过度使用。请注意,尺寸
轴可以是任何有用的轴(0,1,2),并且事先不知道。有什么优雅的方法可以做到这一点吗?

利用-


对于一般情况:

def argmax_to_max(arr, argmax, axis):
    """argmax_to_max(arr, arr.argmax(axis), axis) == arr.max(axis)"""
    new_shape = list(arr.shape)
    del new_shape[axis]

    grid = np.ogrid[tuple(map(slice, new_shape))]
    grid.insert(axis, argmax)

    return arr[tuple(grid)]
不幸的是,这比自然操作要尴尬得多

对于使用
(n-1)dim
数组对
n dim
数组进行索引,我们可以稍微简化它,为所有轴提供索引网格,如下所示-

def all_idx(idx, axis):
    grid = np.ogrid[tuple(map(slice, idx.shape))]
    grid.insert(axis, idx)
    return tuple(grid)
因此,使用它来索引输入数组-

axis = 0
a_max_values = a[all_idx(idx, axis=axis)]
b_max_values = b[all_idx(idx, axis=axis)]

all_idx
非常优雅。爱你的帖子,@Divakar@我现在明白了!因此,如果输入
a
是一个一维数组,那么对于tuple,我们将得到一个标量,它复制
.max()
行为。但是如果没有元组,我们将得到一个只有一个元素的数组。所以,我认为这可能是将其保留为元组的一个原因。
all_idx
很好。我没有意识到使用argmax输出的形状而不是原始数组的形状会简化事情。对于元组而不是列表,高级索引语义使用元组更清晰。在这种情况下,由于在某些条件下将列表转换为元组的向后兼容性处理(文档记录不完全正确),列表的行为恰好相同。在NumPy索引中,当列表被视为元组,当列表被视为数组时,这可能会令人惊讶,因此我更喜欢显式创建元组。
idx
是标量的情况我甚至没有想到,如果我们保留一个列表而不是元组,则不会触发向后兼容性处理,并且会出现错误的结果。如图所示,元组的行为更一致,更容易预测。
take\u沿轴添加了
,使之更容易。和一个
put
axis = 0
a_max_values = a[all_idx(idx, axis=axis)]
b_max_values = b[all_idx(idx, axis=axis)]