Python 什么';根据元素在另一个数组中出现的次数,删除NumPy数组中元素的最有效方法是什么?

Python 什么';根据元素在另一个数组中出现的次数,删除NumPy数组中元素的最有效方法是什么?,python,arrays,numpy,Python,Arrays,Numpy,假设我有两个Numpy阵列: a=np.array([1,2,2,3,3]) b=np.数组([2,2,3]) 我想从a中删除b中的所有元素,删除次数与b中相同。即 diff(a,b) >>>np.数组([1,3,3]) 请注意,对于我的用例,b将始终是a的一个子集,并且两者都可能是无序的,但是像set-like这样的方法不会将其剪切,因为删除每个元素一定的次数很重要 我当前的惰性解决方案如下所示: def差异(a、b): 对于b中的el: idx=(el==a).argmax() 如果a[

假设我有两个Numpy阵列:

a=np.array([1,2,2,3,3])
b=np.数组([2,2,3])
我想从
a
中删除
b
中的所有元素,删除次数与
b
中相同。即

diff(a,b)
>>>np.数组([1,3,3])
请注意,对于我的用例,
b
将始终是
a
的一个子集,并且两者都可能是无序的,但是像set-like这样的方法不会将其剪切,因为删除每个元素一定的次数很重要

我当前的惰性解决方案如下所示:

def差异(a、b):
对于b中的el:
idx=(el==a).argmax()
如果a[idx]==el:
a=np.delete(a,idx)
归还

但我想知道是否有更高性能或更紧凑的“numpy式”写作方式?

您的方法

def dedup_reference(a, b):
    for el in b:
        idx = (el == a).argmax()
        if a[idx] == el:
            a = np.delete(a, idx)
    return a
def dedup_unique(arr, sel):
    d_arr = dict(zip(*np.unique(arr, return_counts=True)))
    d_sel = dict(zip(*np.unique(sel, return_counts=True)))
    d = {k: v - d_sel.get(k, 0) for k, v in d_arr.items()}
    res = np.empty(sum(d.values()), dtype=arr.dtype)
    idx = 0
    for k, count in d.items():
        res[idx:idx+count] = k
        idx += count
    return res
扫描方法需要输入排序:

def dedup_scan(arr, sel):
    arr.sort()
    sel.sort()
    mask = np.ones_like(arr, dtype=np.bool)
    sel_idx = 0
    for i, x in enumerate(arr):
        if sel_idx == sel.size:
            break
        if x == sel[sel_idx]:
            mask[i] = False
            sel_idx += 1
    return arr[mask]
np.独特的
计数方法

def dedup_reference(a, b):
    for el in b:
        idx = (el == a).argmax()
        if a[idx] == el:
            a = np.delete(a, idx)
    return a
def dedup_unique(arr, sel):
    d_arr = dict(zip(*np.unique(arr, return_counts=True)))
    d_sel = dict(zip(*np.unique(sel, return_counts=True)))
    d = {k: v - d_sel.get(k, 0) for k, v in d_arr.items()}
    res = np.empty(sum(d.values()), dtype=arr.dtype)
    idx = 0
    for k, count in d.items():
        res[idx:idx+count] = k
        idx += count
    return res
通过巧妙地使用numpy集函数(例如,
np.inad
),您或许可以实现与上述相同的功能,但我认为这并不比仅仅使用字典快


下面是一个懒散的基准测试尝试(更新后包括@Divakar的
diff_v2
diff_v3
方法):

外卖:

  • 重复数据消除\u参考
    随着重复数据数量的增加,速度会显著减慢
  • 重复数据消除\u unique
    在值范围较小时速度最快
    diff_v3
    速度非常快,不依赖于值的范围
  • 阵列复制时间可以忽略不计
  • 字典很酷

性能特征强烈依赖于数据量(未测试)和数据的统计分布。我建议您使用自己的数据测试这些方法,并选择最快的方法。请注意,不同的解决方案产生不同的输出,并对输入做出不同的假设。

以下是基于-

进一步优化

让我们求助于NumPy来了解groupby cumcount部分-

# Perform groupby cumcount on sorted array
def groupby_cumcount(idx):
    mask = np.r_[False,idx[:-1]==idx[1:],False]
    ids = mask[:-1].cumsum()
    count = np.diff(np.flatnonzero(~mask))
    return ids - np.repeat(ids[~mask[:-1]],count)

def diff_v3(a, b):
    # Get sorted orders
    sidx = a.argsort(kind='stable')
    A = a[sidx]
    
    # Get searchsorted indices per sorted order
    idx = np.searchsorted(A,b)
    
    # Get increments
    idx = np.sort(idx)
    inc = groupby_cumcount(idx)
    
    # Delete elemnents off traced back positions
    return np.delete(a,sidx[idx+inc])
标杆管理 使用
10000
元素的设置,其中
a
b
重复次数为
a
的一半

In [52]: np.random.seed(0)
    ...: a = np.random.randint(0,5000,10000)
    ...: b = a[np.random.choice(len(a), 5000,replace=False)]

In [53]: %timeit diff(a,b)
    ...: %timeit diff_v2(a,b)
    ...: %timeit diff_v3(a,b)
108 ms ± 821 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
3.85 ms ± 53.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
1.89 ms ± 15.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
接下来,关于
100000
元素-

In [54]: np.random.seed(0)
    ...: a = np.random.randint(0,50000,100000)
    ...: b = a[np.random.choice(len(a), 50000,replace=False)]

In [55]: %timeit diff(a,b)
    ...: %timeit diff_v2(a,b)
    ...: %timeit diff_v3(a,b)
4.45 s ± 20.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
37.5 ms ± 661 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
28 ms ± 122 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
用于正数和排序输出

我们可以使用-


这是一种类似的方法,但比@Divakar的方法要快一点(在撰写本文时,可能会有所改变…)

更新:@MateenUlhaq的
重复数据消除\u unique的相应计时:

7.986748410039581
81.83312350302003

请注意,此函数产生的结果与Divakar和我的结果不同(至少不是微不足道的)。

有序:您可以对
for
循环并切换掩码内的位,然后将该掩码应用于numpy数组。可能比使用
np重复创建数组快得多。删除
。是否需要对数组进行排序?它会无序吗?或者您总是需要从左向右删除吗?无序:(1):np.unique和counts,然后使用其中的值重建数组。(2) :使用类似枚举的方法,将每个数组项配对为唯一索引,然后将设置的差异与结果进行比较。@MateenUlhaq当然可以,但我更希望通过仅使用numpy而不使用for循环解决方案来加快速度。如果可能的话。@MateenUlhaq也可能是无序的,更新了帖子。准备挑战吗?顺便说一句,如果可以在这里提问的话,我尝试了
benchit
,但在导入时它死了,抱怨
qtagg
(IIRC)后端。它必须是那个后端还是我可以使用另一个后端?@PaulPanzer你在试笔记本吗?另外,在导入
benchit
之前是否导入matplotlib?是的,时间非常接近:)不,普通的python repl。导入或不导入
matplotlib
似乎没有什么区别。@PaulPanzer如果您能够导入
benchit
,您能从
benchit.print_specs()
中获取输出吗?是窗户吗?看来我需要做更多的测试。
import numpy as np

def pp():
    if a.dtype.kind == "i":
        small = np.iinfo(a.dtype).min
    else:
        small = -np.inf
    ba = np.concatenate([[small],b,a])
    idx = ba.argsort(kind="stable")
    aux = np.where(idx<=b.size,-1,1)
    aux = aux.cumsum()
    valid = aux==np.maximum.accumulate(aux)
    valid[0] = False
    valid[1:] &= valid[:-1]
    aux2 = np.zeros(ba.size,bool)
    aux2[idx[valid]] = True
    return ba[aux2.nonzero()]

def groupby_cumcount(idx):
    mask = np.r_[False,idx[:-1]==idx[1:],False]
    ids = mask[:-1].cumsum()
    count = np.diff(np.flatnonzero(~mask))
    return ids - np.repeat(ids[~mask[:-1]],count)

def diff_v3():
    # Get sorted orders
    sidx = a.argsort(kind='stable')
    A = a[sidx]
    
    # Get searchsorted indices per sorted order
    idx = np.searchsorted(A,b)
    
    # Get increments
    idx = np.sort(idx)
    inc = groupby_cumcount(idx)
    
    # Delete elemnents off traced back positions
    return np.delete(a,sidx[idx+inc])

np.random.seed(0)
a = np.random.randint(0,5000,10000)
b = a[np.random.choice(len(a), 5000,replace=False)]

from timeit import timeit

print(timeit(pp,number=100)*10)
print(timeit(diff_v3,number=100)*10)
print((pp() == diff_v3()).all())

np.random.seed(0)
a = np.random.randint(0,50000,100000)
b = a[np.random.choice(len(a), 50000,replace=False)]

print(timeit(pp,number=10)*100)
print(timeit(diff_v3,number=10)*100)
print((pp() == diff_v3()).all())
1.4644702401710674
1.6345531499246135
True
22.230969095835462
24.67835019924678
True
7.986748410039581
81.83312350302003