具有缩减的最大值的Numpy索引-Numpy.argmax.reduceat

具有缩减的最大值的Numpy索引-Numpy.argmax.reduceat,numpy,vectorization,reduction,argmax,numpy-ufunc,Numpy,Vectorization,Reduction,Argmax,Numpy Ufunc,我有一个平面阵列b: a = numpy.array([0, 1, 1, 2, 3, 1, 2]) 以及标记每个“块”开头的索引数组c: 我知道我可以通过减少找到每个“块”的最大值: m = numpy.maximum.reduceat(a,b) >>> array([2, 3], dtype=int32) 但是。。。有没有一种方法可以通过矢量化操作(无列表、循环)找到块中最大的索引(如numpy.argmax) 涉及的步骤: 按限制偏移量偏移组中的所有图元。对它们进行全

我有一个平面阵列
b

a = numpy.array([0, 1, 1, 2, 3, 1, 2])
以及标记每个“块”开头的索引数组
c

我知道我可以通过减少找到每个“块”的最大值:

m = numpy.maximum.reduceat(a,b)
>>> array([2, 3], dtype=int32)
但是。。。有没有一种方法可以通过矢量化操作(无列表、循环)找到块
中最大
的索引(如
numpy.argmax

涉及的步骤:

  • 按限制偏移量偏移组中的所有图元。对它们进行全局排序,从而限制每个组停留在它们的位置,但对每个组内的元素进行排序

  • 在排序数组中,我们将查找最后一个元素,即组最大值。它们的索引将是组长度向下偏移后的argmax

因此,矢量化的实现将是-

def numpy_argmax_reduceat(a, b):
    n = a.max()+1  # limit-offset
    grp_count = np.append(b[1:] - b[:-1], a.size - b[-1])
    shift = n*np.repeat(np.arange(grp_count.size), grp_count)
    sortidx = (a+shift).argsort()
    grp_shifted_argmax = np.append(b[1:],a.size)-1
    return sortidx[grp_shifted_argmax] - b
作为一个小的调整,也可能是更快的调整,我们可以使用
cumsum
创建
shift
,这样就有了早期方法的变体,如下所示-

def numpy_argmax_reduceat_v2(a, b):
    n = a.max()+1  # limit-offset
    id_arr = np.zeros(a.size,dtype=int)
    id_arr[b[1:]] = 1
    shift = n*id_arr.cumsum()
    sortidx = (a+shift).argsort()
    grp_shifted_argmax = np.append(b[1:],a.size)-1
    return sortidx[grp_shifted_argmax] - b

暂时删除了我的问题,因为我认为我有一个答案:
numpy.argmax(numpy.equal.outer(m,a),axis=1)
,但这不适用于在许多地方出现相同max的示例…例如在这个数组上:
a=numpy.array([0,1,1,3,3,1,2])
,在两个块中出现相同的最大值
3
。问题是
np.max
是一个
ufunc
reduceat
-有效地迭代数组,一次比较两个值。但是
np.max
np.argmax
是同时在整个数组上运行的函数。他们不是
ufunc
@hpaulj,是的,我知道。我在问是否有人能想出一个具有相同行为的解决方案。在我的例子中,两种解决方案都很有效,因为我已经有了先前操作的
shift
。回答得好。嘿,几个月前你已经回答了我的问题。你能看看这个帖子吗:。这与这个问题有关。
def numpy_argmax_reduceat_v2(a, b):
    n = a.max()+1  # limit-offset
    id_arr = np.zeros(a.size,dtype=int)
    id_arr[b[1:]] = 1
    shift = n*id_arr.cumsum()
    sortidx = (a+shift).argsort()
    grp_shifted_argmax = np.append(b[1:],a.size)-1
    return sortidx[grp_shifted_argmax] - b