Python 在numpy nd数组中查找重复值

Python 在numpy nd数组中查找重复值,python,numpy,Python,Numpy,我的每个数据样本都是一个形状的numpy数组,例如(100,100,9),我将其中10个连接到一个形状(10,100,100,9)的数组foo。在10个数据样本中,我想找到重复值的索引。例如,如果foo[0,42,42,3]=0.72和foo[0,42,42,7]=0.72,我想要一个反映这一点的输出。这样做的有效方式是什么 我正在考虑一个形状为(1001009)的布尔输出数组,但是有没有比循环更好的方法来比较每个数据样本(数据样本数的二次运行时(10))?这里有一个在每个样本上使用argsor

我的每个数据样本都是一个形状的numpy数组,例如(100,100,9),我将其中10个连接到一个形状(10,100,100,9)的数组
foo
。在10个数据样本中,我想找到重复值的索引。例如,如果
foo[0,42,42,3]=0.72
foo[0,42,42,7]=0.72
,我想要一个反映这一点的输出。这样做的有效方式是什么


我正在考虑一个形状为(1001009)的布尔输出数组,但是有没有比循环更好的方法来比较每个数据样本(数据样本数的二次运行时(10))?

这里有一个在每个样本上使用
argsort
的解决方案。不漂亮,不快,但能胜任

import numpy as np
from timeit import timeit

def dupl(a, axis=0, make_dict=True):
    a = np.moveaxis(a, axis, -1)
    i = np.argsort(a, axis=-1, kind='mergesort')
    ai = a[tuple(np.ogrid[tuple(map(slice, a.shape))][:-1]) + (i,)]
    same = np.zeros(a.shape[:-1] + (a.shape[-1]+1,), bool)
    same[..., 1:-1] = np.diff(ai, axis=-1) == 0
    uniqs = np.where((same[..., 1:] & ~same[..., :-1]).ravel())[0]
    same = (same[...,1:]|same[...,:-1]).ravel()
    reps = np.split(i.ravel()[same], np.cumsum(same)[uniqs[1:]-1])
    grps = np.searchsorted(uniqs, np.arange(0, same.size, a.shape[-1]))
    keys = ai.ravel()[uniqs]
    if make_dict:
        result = np.empty(a.shape[:-1], object)
        result.ravel()[:] = [dict(zip(*p)) for p in np.split(
                np.array([keys, reps], object), grps[1:], axis=-1)]
        return result
    else:
        return keys, reps, grps

a = np.random.randint(0,10,(10,100,100,9))
axis = 0
result = dupl(a, axis)

print('shape, axis, time (sec) for 10 trials:', 
      a.shape, axis, timeit(lambda: dupl(a, axis=axis), number=10))
print('same without creating dict:', 
      a.shape, axis, timeit(lambda: dupl(a, axis=axis, make_dict=False),
                            number=10))

#check
print("checking result")
am = np.moveaxis(a, axis, -1)
for af, df in zip(am.reshape(-1, am.shape[-1]), result.ravel()):
    assert len(set(af)) + sum(map(len, df.values())) == len(df) + am.shape[-1]
    for k, v in df.items():
        assert np.all(np.where(af == k)[0] == v)
print("no errors")
印刷品:

shape, axis, time (sec) for 10 trials: (10, 100, 100, 9) 0 5.328339613042772
same without creating dict: (10, 100, 100, 9) 0 2.568383438978344
checking result
no errors

在下面的代码段中,
dups
是所需的结果:一个布尔数组,显示哪些索引是重复的。还有一个delta阈值,因此值之间的任何差异是您只想标记任何具有重复值的值,还是希望使用数据值作为键的字典,是否将索引复制为字典值?@James问题属于一般性问题,没有指定返回的精确数据,以避免限制可能的解决方案,但我认为是一个布尔数组,它只是通过索引标记重复项(如上所述)。这段代码中到处都有气味,“还有这么多杂乱无章的分类,好像这是在试图降低效率。”波尔兹曼的大脑有点刺耳,你不觉得吗?此外,与您的不同,这具有合理的复杂性,不是O(nk^2),而是O(n(logn/k+klogk))。我承认这对你的眼睛不容易,但不要仅仅因为它超出了你的能力就去尝试它。
delta = 0.
dups = np.zeros(foo.shape[1:], dtype=bool)
for i in xrange(foo.shape[0]):
    for j in xrange(foo.shape[0]):
        if i==j: continue
        dups += abs(foo[i] - foo[j]) <= delta