Python numpy 3d阵列-沿给定轴选择最大元素

Python numpy 3d阵列-沿给定轴选择最大元素,python,arrays,numpy,multidimensional-array,Python,Arrays,Numpy,Multidimensional Array,我有一个numpy数组,它的维数(x,y,z)=(5,50,4)。对于每个(x,y)对,我想找到沿z轴的最大值的索引。此索引在范围(4)内。我想选择所有这些“最大”元素并将它们设置为1。然后,我想选择所有其他元素并将它们设置为零 换一种方式来解释,我想看看“z”方向上的所有向量(这些向量总共有x*y)。我想将最大元素设置为1,将所有其他元素设置为0。例如,向量(0.25,0.1,0.5,0.15)将变为(0,0,1,0) 我试过很多不同的方法。argmax函数似乎应该有所帮助。但是如何使用它正确

我有一个numpy数组,它的维数(x,y,z)=(5,50,4)。对于每个(x,y)对,我想找到沿z轴的最大值的索引。此索引在范围(4)内。我想选择所有这些“最大”元素并将它们设置为1。然后,我想选择所有其他元素并将它们设置为零

换一种方式来解释,我想看看“z”方向上的所有向量(这些向量总共有x*y)。我想将最大元素设置为1,将所有其他元素设置为0。例如,向量(0.25,0.1,0.5,0.15)将变为(0,0,1,0)

我试过很多不同的方法。argmax函数似乎应该有所帮助。但是如何使用它正确地选择元素呢?我试过像

x = data
i = x.argmax(axis = 2)
x[i] # shape = (5, 50, 50, 4)
x[:,:,i] # shape = (5, 50, 5, 50)
x[np.unravel_index(i), x.shape] # shape = (5, 50)
最后一个使用np.unravel_索引的索引具有正确的形状,但选定的索引不是沿z轴的最大值。所以我有点麻烦。如果有人能帮上忙,那就太棒了。谢谢


编辑:下面是我找到的一种方法。但是如果有人有更快的,请告诉我

def fix_vector(a):
    i = a.argmax()
    a = a*0
    a[i] = 1
    return a

y = np.apply_along_axis(fix_vector, axis=2, arr=x)
如果可能的话,我真的很想优化它,因为我多次调用这个函数


编辑:感谢DSM提供了一个很好的解决方案。下面是一个小示例数据集,如注释中所要求的

data = np.random.random((3,5,4))
desired_output = np.apply_along_axis(fix_vector, axis=2, arr=data)

这使用了我在上面发布的fix_向量函数,但是DSM的解决方案更快。再次感谢

这不是特别优雅,但是:

def faster(x):
    d = x.reshape(-1, x.shape[-1])
    d2 = np.zeros_like(d)
    d2[np.arange(len(d2)), d.argmax(1)] = 1
    d2 = d2.reshape(x.shape)
    return d2
似乎比
fix\u vector
方法快一点。例如:

>>> x,y,z = 5,50,4
>>> data = np.random.random((x,y,z))
>>> np.allclose(orig(data), faster(data))
True
>>> %timeit -n 1000 orig(data)
1000 loops, best of 3: 5.77 ms per loop
>>> %timeit -n 1000 faster(data)
1000 loops, best of 3: 36.6 µs per loop

这可以通过
numpy.where

import numpy as np
a = np.array([[[ 0.25,  0.10 ,  0.50 ,  0.15],
               [ 0.50,  0.60 ,  0.40 ,  0.30]],
              [[ 0.25,  0.50 ,  0.20 ,  0.70],
               [ 0.80,  0.10 ,  0.50 ,  0.15]]])
沿最后一个轴查找最大值

b = a.max(-1)    #shape is (2,2)
b
添加一个轴,使其穿过
a
并创建一个布尔数组

condition = a == b[..., np.newaxis]
使用
numpy.where
进行替换

c = np.where(condition, 1, 0)

>>> c
array([[[0, 0, 1, 0],
        [0, 1, 0, 0]],

       [[0, 0, 0, 1],
        [1, 0, 0, 0]]])


所以我不满意-我见过这样的情况,
重塑
比添加一个新轴要快,所以我玩了一下。似乎
numpy。其中
本身有点慢。使用带有赋值的布尔索引可以提供与@DSM几乎相同的性能

def h(a):
    b = a.max(-1)
    condition = a == b[..., np.newaxis]
    a[condition] = 1
    a[np.logical_not(condition)] = 0
    return a

样本
x
(2,3,4)
形状:

In [934]: x
Out[934]: 
array([[[ 5,  7,  4, 15],
        [ 9, 23, 14,  9],
        [17, 13,  2,  1]],

       [[ 2, 18,  3, 18],
        [10, 12, 23, 14],
        [17, 23, 19, 13]]])
沿最后一个轴的
max
值-足够清晰,一个(2,3)数组:

沿着该轴的坐标

In [936]: I=x.argmax(2)

In [937]: I
Out[937]: 
array([[3, 1, 0],
       [1, 2, 1]], dtype=int32)
使用这些索引索引
x
的一种方法。此时,我正在手动构造0轴和1轴索引。它需要自动化

In [938]: x[np.arange(2)[:,None], np.arange(3)[None,:], I]
Out[938]: 
array([[15, 23, 17],
       [18, 23, 23]])
相同的索引可用于设置另一个数组中的值:

In [939]: y=np.zeros_like(x)   
In [940]: y[np.arange(2)[:,None],np.arange(3)[None,:],I]=x.max(2)
In [941]: y
Out[941]: 
array([[[ 0,  0,  0, 15],
        [ 0, 23,  0,  0],
        [17,  0,  0,  0]],

       [[ 0, 18,  0,  0],
        [ 0,  0, 23,  0],
        [ 0, 23,  0,  0]]])
诀窍是提出一种更简化的生成索引元组的方法:

(np.arange(2)[:,None], np.arange(3)[None,:], I)

(array([[0],
    [1]]), 
 array([[0, 1, 2]]), 
 array([[3, 1, 0],
    [1, 2, 1]], dtype=int32))
np.ix
有帮助,尽管连接两个元组有点难看:

J = np.ix_(range(x.shape[0]), range(x.shape[1])) + (I,)
在展平阵列上建立索引可以更快:

In [998]: J1 = np.ravel_multi_index(J, x.shape)
In [999]: J1
Out[999]: 
array([[ 3,  5,  8],
     [13, 18, 21]], dtype=int32)

In [1000]: x.flat[J1]
Out[1000]: 
array([[15, 23, 17],
   [18, 23, 23]])

In [1001]: y.flat[J1]=2

In [1002]: y
Out[1002]: 
array([[[0, 0, 0, 2],
        [0, 2, 0, 0],
        [2, 0, 0, 0]],

       [[0, 2, 0, 0],
        [0, 0, 2, 0],
        [0, 2, 0, 0]]])
但是我仍然需要构造
J
索引元组

这类似于DSM的解决方案,但坚持三维。将2个尺寸展平,无需进行
xi
构造


是一个通用函数,用于生成类似
J
的索引元组

e、 g


x的最小值和最大值(沿轴1)插入
y
。速度大约是DSM的一半
s
faster

处理后,您应该添加一个小的输入数组示例和所需的结果。可比但稍低(50%)比@DSM慢-看起来它们的缩放比例与数组大小相同。请注意,如果最大值不唯一,这可能会导致一行中出现多个1。但据我们所知,这可能是OP希望在这种情况下发生的事情(这个问题在这一点上是不明确的)。@DSM-您选择的是连续多个最大值的最低索引?感谢您的解决方案,二战。这正是我想要的。仅供参考,对于我的应用程序,我从未遇到两个相同的值,因此两个解决方案都应该有效。感谢您的解决方案,hpaulj。我认为DSM和二战提供了更优雅的解决方案,但看到这一点也很有趣。
J = np.ix_(range(x.shape[0]), range(x.shape[1])) + (I,)
In [998]: J1 = np.ravel_multi_index(J, x.shape)
In [999]: J1
Out[999]: 
array([[ 3,  5,  8],
     [13, 18, 21]], dtype=int32)

In [1000]: x.flat[J1]
Out[1000]: 
array([[15, 23, 17],
   [18, 23, 23]])

In [1001]: y.flat[J1]=2

In [1002]: y
Out[1002]: 
array([[[0, 0, 0, 2],
        [0, 2, 0, 0],
        [2, 0, 0, 0]],

       [[0, 2, 0, 0],
        [0, 0, 2, 0],
        [0, 2, 0, 0]]])
def max_index(x, n, fn=np.argmax):
    # return multidimensional indexing tuple
    # applying fn along axis n
    I = fn(x, axis=n)
    alist = list(np.ix_(*[range(i) for i in I.shape]))
    alist.insert(n, I)
    return tuple(alist)
y = np.zeros_like(x)
I = max_index(x, 1, np.argmin)
y[I] = x[I]
I = max_index(x, 1, np.argmax)
y[I] = x[I]