Python 改进比较算法np.packbits(A==A[:,None],axis=1)的性能

Python 改进比较算法np.packbits(A==A[:,None],axis=1)的性能,python,arrays,algorithm,numpy,boolean,Python,Arrays,Algorithm,Numpy,Boolean,我希望内存优化np.packbits(A==A[:,None],axis=1),其中A是长度n的密集整数数组A==A[:,None]需要大量的n内存,因为生成的布尔数组存储效率低下,每个布尔值消耗1个字节 我编写下面的脚本是为了在一次一节地打包位时获得相同的结果。然而,它大约慢3倍,所以我正在寻找加速的方法。或者,一个内存开销小的更好的算法 注意:这是我之前提出的问题的后续问题 以下为基准测试的可复制代码 import numpy as np from numba import jit @ji

我希望内存优化
np.packbits(A==A[:,None],axis=1)
,其中
A
是长度
n
的密集整数数组
A==A[:,None]
需要大量的
n
内存,因为生成的布尔数组存储效率低下,每个布尔值消耗1个字节

我编写下面的脚本是为了在一次一节地打包位时获得相同的结果。然而,它大约慢3倍,所以我正在寻找加速的方法。或者,一个内存开销小的更好的算法

注意:这是我之前提出的问题的后续问题

以下为基准测试的可复制代码

import numpy as np
from numba import jit

@jit(nopython=True)
def bool2int(x):
    y = 0
    for i, j in enumerate(x):
        if j: y += int(j)<<(7-i)
    return y

@jit(nopython=True)
def compare_elementwise(arr, result, section):
    n = len(arr)
    for row in range(n):
        for col in range(n):

            section[col%8] = arr[row] == arr[col]

            if ((col + 1) % 8 == 0) or (col == (n-1)):
                result[row, col // 8] = bool2int(section)
                section[:] = 0

    return result

n = 10000
A = np.random.randint(0, 1000, n)

result_arr = np.zeros((n, n // 8 if n % 8 == 0 else n // 8 + 1)).astype(np.uint8)
selection_arr = np.zeros(8).astype(np.uint8)

# memory efficient version, but slow
packed = compare_elementwise(A, result_arr, selection_arr)

# memory inefficient version, but fast
packed2 = np.packbits(A == A[:, None], axis=1)

assert (packed == packed2).all()

%timeit compare_elementwise(A, result_arr, selection_arr)  # 1.6 seconds
%timeit np.packbits(A == A[:, None], axis=1)  # 0.460 second
将numpy导入为np
从numba导入jit
@jit(nopython=True)
def bool2int(x):
y=0
对于枚举(x)中的i,j:

如果j:y+=int(j)这里有一个比numpy快3倍的解决方案(a.size必须是8的倍数;见下文):

这是因为阵列只扫描一次,而您需要多次扫描, 而且大多数术语都是空的

In [122]: %timeit np.packbits(A == A[:, None], axis=1)
389 ms ± 57.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [123]: %timeit comp(A)
123 ms ± 24.4 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
如果
a.size%8>0
,则查找信息的成本将更高。在这种情况下,最好的方法是用一些(在
范围(7)
中)零填充初始数组

为了完整性,填充可以这样做:

if A.size % 8 != 0: A = np.pad(A, (0, 8 - A.size % 8), 'constant', constant_values=0)
if A.size % 8 != 0: A = np.pad(A, (0, 8 - A.size % 8), 'constant', constant_values=0)