Python 分区索引上的组argmax/argmin(单位:numpy)

Python 分区索引上的组argmax/argmin(单位:numpy),python,numpy,Python,Numpy,Numpy的ufuncs有一个方法可以在数组中的连续分区上运行它们。因此,与其写: import numpy as np a = np.array([4, 0, 6, 8, 0, 9, 8, 5, 4, 9]) split_at = [4, 5] maxima = [max(subarray for subarray in np.split(a, split_at)] 我可以写: maxima = np.maximum.reduceat(a, np.hstack([0, split_at]))

Numpy的
ufunc
s有一个方法可以在数组中的连续分区上运行它们。因此,与其写:

import numpy as np
a = np.array([4, 0, 6, 8, 0, 9, 8, 5, 4, 9])
split_at = [4, 5]
maxima = [max(subarray for subarray in np.split(a, split_at)]
我可以写:

maxima = np.maximum.reduceat(a, np.hstack([0, split_at]))
这两个函数都将返回最大值,单位为切片
a[0:4]
a[4:5]
a[5:10]
,即
[8,0,9]

我希望执行类似的函数,注意我只希望每个分区中有一个最大索引:
[3,4,5]
,上面的
a
split_at
(尽管索引5和9在最后一组中都获得了最大值),正如

np.hstack([0, split_at]) + [np.argmax(subarray) for subarray in np.split(a, split_at)]

我将在下面发布一个可能的解决方案,但希望看到一个不在组上创建索引的矢量化解决方案。

此解决方案涉及在组上构建索引(
[0,0,0,0,1,2,2,2]
在上面的示例中)

然后我们可以使用:

maxima = np.maximum.reduceat(a, np.hstack([0, split_at]))
all_argmax = np.flatnonzero(np.repeat(maxima, group_lengths) == a)
result = np.empty(len(group_lengths), dtype='i')
result[index[all_argmax[::-1]]] = all_argmax[::-1]
结果中获取
[3,4,5]
[::-1]
s确保我们在每个组中获得第一个而不是最后一个argmax


这取决于这样一个事实,即fancy赋值中的最后一个索引决定了赋值@seberg(更安全的选择是使用
result=all\u argmax[np.unique(index[all\u argmax],return\u index=True)[1]
,这涉及到对
len(maxima)~n组
元素的排序).

受这个问题的启发,我在软件包中添加了argmin/max功能。下面是相应测试的样子。请注意,密钥可以是任何顺序(以及npi支持的任何类型):


它的算法复杂性类似于您的
np。独特的
解决方案,但根本不涉及
索引
数组。一旦你有了
all\u argmax
你就可以直接做:
all\u argmax[np.searchsorted(all\u argmax,np.hstack([0,split\u at])]
。谢谢,@Jaime,
searchsorted
总是出现在我想不到的地方。@Jaime你提出的解决方案不会给我与
np unique
实现相同的结果。我遗漏了什么吗?是的,我写的代码显然是错误的:
all_argmax
是一个未排序的数组,因此不能对其调用
np.searchsorted
。我没有运行它,所以bug可能仍然存在,但我认为最里面的
all_argmax
调用缺少
np.nonzero()
,即:
all_argmax[np.searchsorted(np.nonzero(all_argmax)),np.hstack([0,split_at])
maxima = np.maximum.reduceat(a, np.hstack([0, split_at]))
all_argmax = np.flatnonzero(np.repeat(maxima, group_lengths) == a)
result = np.empty(len(group_lengths), dtype='i')
result[index[all_argmax[::-1]]] = all_argmax[::-1]
def test_argmin():
    keys   = [2, 0, 0, 1, 1, 2, 2, 2, 2, 2]
    values = [4, 5, 6, 8, 0, 9, 8, 5, 4, 9]
    unique, amin = group_by(keys).argmin(values)
    npt.assert_equal(unique, [0, 1, 2])
    npt.assert_equal(amin,   [1, 4, 0])