在Python中检查高维数组在两个数据数组中重叠的有效方法

在Python中检查高维数组在两个数据数组中重叠的有效方法,python,arrays,performance,numpy,set-operations,Python,Arrays,Performance,Numpy,Set Operations,例如,我有两个数据集,train\u数据集的形状是(10000,28,28),val\u数据集的形状是(2000,28,28) 除了使用迭代之外,还有什么有效的方法可以使用numpy数组函数来查找两个ndarray之间的重叠吗?内存允许您这样使用- val_dateset[(train_dataset[:,None] == val_dateset).all(axis=(2,3)).any(0)] 样本运行- In [55]: train_dataset Out[55]: array([[[1

例如,我有两个数据集,
train\u数据集的形状是
(10000,28,28)
val\u数据集的形状是
(2000,28,28)


除了使用迭代之外,还有什么有效的方法可以使用numpy数组函数来查找两个ndarray之间的重叠吗?

内存允许您这样使用-

val_dateset[(train_dataset[:,None] == val_dateset).all(axis=(2,3)).any(0)]
样本运行-

In [55]: train_dataset
Out[55]: 
array([[[1, 1],
        [1, 1]],

       [[1, 0],
        [0, 0]],

       [[0, 0],
        [0, 1]],

       [[0, 1],
        [0, 0]],

       [[1, 1],
        [1, 0]]])

In [56]: val_dateset
Out[56]: 
array([[[0, 1],
        [1, 0]],

       [[1, 1],
        [1, 1]],

       [[0, 0],
        [0, 1]]])

In [57]: val_dateset[(train_dataset[:,None] == val_dateset).all(axis=(2,3)).any(0)]
Out[57]: 
array([[[1, 1],
        [1, 1]],

       [[0, 0],
        [0, 1]]])

如果元素是整数,则可以折叠轴=(1,2)的每个块
将输入数组转换为标量,假设它们是可线性索引的数字,然后有效地使用
np.in1d
np.intersect1d
查找匹配项。

全广播在此生成一个10000*2000*28*28=150 Mo的布尔数组

为了提高效率,您可以:

  • 打包数据,对于200千卡阵列:

    from pylab import *
    N=10000
    a=rand(N,28,28)
    b=a[[randint(0,N,N//5)]]
    
    packedtype='S'+ str(a.size//a.shape[0]*a.dtype.itemsize) # 'S6272' 
    ma=frombuffer(a,packedtype)  # ma.shape=10000
    mb=frombuffer(b,packedtype)  # mb.shape=2000
    
    %timeit a[:,None]==b   : 102 s
    %timeit ma[:,None]==mb   : 800 ms
    allclose((a[:,None]==b).all((2,3)),(ma[:,None]==mb)) : True
    
    延迟字符串比较有助于减少内存,在第一个差异时中断:

    In [31]: %timeit a[:100]==b[:100]
    10000 loops, best of 3: 175 µs per loop
    
    In [32]: %timeit a[:100]==a[:100]
    10000 loops, best of 3: 133 µs per loop
    
    In [34]: %timeit ma[:100]==mb[:100]
    100000 loops, best of 3: 7.55 µs per loop
    
    In [35]: %timeit ma[:100]==ma[:100]
    10000 loops, best of 3: 156 µs per loop
    
这里给出的解决方案带有
(ma[:,None]==mb)。非零()

  • 对于
    (Na+Nb)ln(Na+Nb)
    复杂度,针对
    Na*Nb
    完全比较:

    %timeit in1d(ma,mb).nonzero()  : 590ms 
    
这不是一个很大的收益,但渐进地更好

解决方案 您可以使用返回的索引数组对
b
进行索引,以提取在
a

b[overlap(a,b)]
解释 为了简单起见,我假设您已经从
numpy
导入了本例中的所有内容:

from numpy import *
例如,给定两个nDarray

a = arange(4*2*2).reshape(4,2,2)
b = arange(3*2*2).reshape(3,2,2)
我们重复
a
b
,使它们具有相同的形状

aa = a.repeat(b.shape[0],axis=0)
bb = b.repeat(a.shape[0],axis=0)
然后我们可以简单地比较
aa
bb

c = aa == bb
最后,通过查看
c
th元素的每4个或实际上每个
shape(a)[0]
th元素,获得
b
中元素的索引,这些元素也可以在
a
中找到

cc == c[::a.shape[0]]
最后,我们提取一个只包含元素的索引数组,其中子数组中的所有元素都是
True

c.all(axis=1)[:,0]
在我们的例子中,我们得到

array([True,  True,  True], dtype=bool)
要进行检查,请更改
b的第一个元素

b[0] = array([[50,60],[70,80]])
我们得到了

array([False,  True,  True], dtype=bool)
我学到的一个技巧是使用
np.void
dtype,以便将输入数组中的每一行作为单个元素查看。这允许您将它们视为1D数组,然后可以将其传递给一个或另一个数组

例如:

gen = np.random.RandomState(0)

A = gen.randn(1000, 28, 28)
dupe_idx = gen.choice(A.shape[0], size=200, replace=False)
B = A[dupe_idx]

A_in_B = find_overlap(A, B)

print(np.all(np.where(A_in_B)[0] == np.sort(dupe_idx)))
# True
这种方法比Divakar的内存效率要高得多,因为它不需要广播到
(m,n,…)
布尔数组。事实上,如果
A
B
是行主项,则根本不需要复制


为了比较,我稍微修改了Divakar和B.M.的解决方案

def divakar(A, B):
    A.shape = A.shape[0], -1
    B.shape = B.shape[0], -1
    return (B[:,None] == A).all(axis=(2)).any(0)

def bm(A, B):
    t = 'S' + str(A.size // A.shape[0] * A.dtype.itemsize)
    ma = np.frombuffer(np.ascontiguousarray(A), t)
    mb = np.frombuffer(np.ascontiguousarray(B), t)
    return (mb[:, None] == ma).any(0)
基准: 如您所见,对于小n,B.M.的解决方案比我的略快,但
np.inad
比测试所有元素的相等性(O(n logn)而不是O(n²)复杂度)更具扩展性


对于这种大小的阵列,Divakar的解决方案在我的笔记本电脑上很难实现,因为它需要生成15GB的中间阵列,而我只有8GB的RAM。

这个问题来自谷歌的在线深度学习课程? 以下是我的解决方案:

sum = 0 # number of overlapping rows
for i in range(val_dataset.shape[0]): # iterate over all rows of val_dataset
    overlap = (train_dataset == val_dataset[i,:,:]).all(axis=1).all(axis=1).sum()
    if overlap:
        sum += 1
print(sum)

使用自动广播代替迭代。您可以测试性能差异。

您能解释一下“重叠”的确切含义吗?您正在查找在
train_数据集
val_数据集
中找到的行索引吗?对,我想找出两个数据集中出现的元素(28*28)。如果您试图创建训练和验证数据集,您最好使用scikit learn。它实际上更清晰、更简洁,+1。我建议不要使用numpy import中的
来污染您的命名空间*
太好了!从来都不知道这样的字符串可以实现短路。但是可能会在其中添加一个
ascontiguousarray
,这样它就可以对
a=rand(28,28,N)这样的数组正常工作。在本例中,我只是创建数组,因此它是连续的,但是对于外部数据,
ma=frombuffer(ascontiguousarray(a),packedtype)
更安全。谢谢。谢谢。这个解决方案真的很有帮助。不过,我正试图更好地理解它。使用
np.ascontiguousarray(A.reforme(A.shape[0],-1))
而不是
np.array([x.flatte()表示A中的x])
的原因是什么?我看到其他代码也使用它来做类似的事情?这是一种风格的东西,还是他们做了一些不同的事情?@Barker试着对这两行进行计时,以获得一个相当大的输入数组。首先,列表理解几乎肯定比单个调用
重塑
要慢,特别是当
中有很多行时。另外,
.flatte()
始终返回一个副本(而
.reforme()
.ravel()
仅在必要时返回一个副本),因此您要对
中的每一行创建一个临时副本,然后在调用列表上的
np.array(…)
时创建另一个副本
np.ascontiguousarray
仅在必要时返回一个副本,因此我的行最多创建
a
的一个副本。这是唯一适用于较大数据集的解决方案(本页)。它应该有更多的选票!
def divakar(A, B):
    A.shape = A.shape[0], -1
    B.shape = B.shape[0], -1
    return (B[:,None] == A).all(axis=(2)).any(0)

def bm(A, B):
    t = 'S' + str(A.size // A.shape[0] * A.dtype.itemsize)
    ma = np.frombuffer(np.ascontiguousarray(A), t)
    mb = np.frombuffer(np.ascontiguousarray(B), t)
    return (mb[:, None] == ma).any(0)
In [1]: na = 1000; nb = 200; rowshape = 28, 28

In [2]: %%timeit A = gen.randn(na, *rowshape); idx = gen.choice(na, size=nb, replace=False); B = A[idx]
divakar(A, B)
   ....: 
1 loops, best of 3: 244 ms per loop

In [3]: %%timeit A = gen.randn(na, *rowshape); idx = gen.choice(na, size=nb, replace=False); B = A[idx]
bm(A, B)
   ....: 
100 loops, best of 3: 2.81 ms per loop

In [4]: %%timeit A = gen.randn(na, *rowshape); idx = gen.choice(na, size=nb, replace=False); B = A[idx]
find_overlap(A, B)
   ....: 
100 loops, best of 3: 15 ms per loop
In [5]: na = 10000; nb = 2000; rowshape = 28, 28

In [6]: %%timeit A = gen.randn(na, *rowshape); idx = gen.choice(na, size=nb, replace=False); B = A[idx]
bm(A, B)
   ....: 
1 loops, best of 3: 271 ms per loop

In [7]: %%timeit A = gen.randn(na, *rowshape); idx = gen.choice(na, size=nb, replace=False); B = A[idx]
find_overlap(A, B)
   ....: 
10 loops, best of 3: 123 ms per loop
sum = 0 # number of overlapping rows
for i in range(val_dataset.shape[0]): # iterate over all rows of val_dataset
    overlap = (train_dataset == val_dataset[i,:,:]).all(axis=1).all(axis=1).sum()
    if overlap:
        sum += 1
print(sum)