Python 使用矩阵切片使循环更有效

Python 使用矩阵切片使循环更有效,python,performance,Python,Performance,我想让下面的代码更有效率,但我不知道如何做到。我只想使用numpy和本机python库 iterations = 100 aggregation = 0 for i in range(iterations): aggregation += np.sum(np.linalg.norm(dat[dat_filter==i] - dat_points[i], axis=1)) dat是一个nxD矩阵 dat_filter是长度为n的向量,包含从0到num_迭代的标识符 dat_points是

我想让下面的代码更有效率,但我不知道如何做到。我只想使用numpy和本机python库

iterations = 100
aggregation = 0
for i in range(iterations):
    aggregation += np.sum(np.linalg.norm(dat[dat_filter==i] - dat_points[i], axis=1))
dat是一个nxD矩阵 dat_filter是长度为n的向量,包含从0到num_迭代的标识符 dat_points是num_迭代器x D矩阵


基本上,我是在计算一个矩阵Dat(其点属于某个类)与该类点之间的距离,因为数据部分的平方根长度不完全相同,所以向量化问题并不容易。您可以将其部分矢量化,以获得较小的速度提升:

import numpy as np

# Make some data
n = 200000
d = 100
iterations = 2000

np.random.seed(42)
dat = np.random.random((n, d))
dat_filter = np.random.randint(0, n_it, size=n)
dat_points = np.random.random((n_it, d))


def slow(dat, dat_filter, dat_points, iterations):
    aggregation = 0
    for i in range(iterations):
        # Wrote linalg.norm as standard numpy operations,
        # such that numba can be used on the code as well
        aggregation += np.sum(np.sqrt(np.sum((dat[dat_filter==i] - dat_points[i])**2, axis=1)))
    return aggregation

def fast(dat, dat_filter, dat_points, iterations):
    # Rearrange the arrays such that the correct operations are done
    sort_idx = np.argsort(dat_filter)
    filtered_dat_squared_sum = np.sum((dat - dat_points[dat_filter])**2, axis=1)[sort_idx]
    # Count the number of different 'iterations'
    counts = np.unique(dat_filter, return_counts=True)[1]
    aggregation = 0 
    idx = 0 
    for c in counts:
        aggregation += np.sum(np.sqrt(filtered_dat_squared_sum[idx:idx+c]))
        idx += c
    return aggregation
时间:

In [1]: %timeit slow(dat, dat_filter, dat_points, n_it)       
3.47 s ± 314 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [2]: %timeit fast(dat, dat_filter, dat_points, n_it)     
846 ms ± 81.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

使用带有
slow
功能的numba会稍微加快速度,但仍然不如
fast
方法快。使用
fast
功能的Numba使我测试的矩阵大小的调用变慢。

太棒了!那真的很有帮助。谢谢