Python Cython的加速比不是';没有预期的那么大

Python Cython的加速比不是';没有预期的那么大,python,numpy,cython,Python,Numpy,Cython,我编写了一个Python函数,用于计算大量(N~10^3)粒子之间的成对电磁相互作用,并将结果存储在NxN complex128数组中。它可以运行,但它是一个较大程序中最慢的部分,当N=900时需要大约40秒[corrected]。原始代码如下所示: import numpy as np def interaction(s,alpha,kprop): # s is an Nx3 real array # alpha is comp

我编写了一个Python函数,用于计算大量(N~10^3)粒子之间的成对电磁相互作用,并将结果存储在NxN complex128数组中。它可以运行,但它是一个较大程序中最慢的部分,当N=900时需要大约40秒[corrected]。原始代码如下所示:

import numpy as np
def interaction(s,alpha,kprop): # s is an Nx3 real array 
                                # alpha is complex
                                # kprop is float

    ndipoles = s.shape[0]

    Amat = np.zeros((ndipoles,3, ndipoles, 3), dtype=np.complex128)
    I = np.array([[1,0,0],[0,1,0],[0,0,1]])
    im = complex(0,1)

    k2 = kprop*kprop

    for i in range(ndipoles):
        xi = s[i,:]
        for j in range(ndipoles):
            if i != j:
                xj = s[j,:]
                dx = xi-xj
                R = np.sqrt(dx.dot(dx))
                n = dx/R
                kR = kprop*R
                kR2 = kR*kR
                A = ((1./kR2) - im/kR)
                nxn = np.outer(n, n)
                nxn = (3*A-1)*nxn + (1-A)*I
                nxn *= -alpha*(k2*np.exp(im*kR))/R
            else:
                nxn = I

            Amat[i,:,j,:] = nxn

    return(Amat.reshape((3*ndipoles,3*ndipoles)))
我以前从未使用过Cython,但这似乎是我加快速度的一个很好的起点,所以我几乎盲目地采用了在线教程中的技术。我得到了一些加速(30秒对40秒),但没有我预期的那么戏剧性,所以我想知道我是做错了什么还是错过了一个关键步骤。以下是我对上述例行程序的最佳尝试:

import numpy as np
cimport numpy as np

DTYPE = np.complex128
ctypedef np.complex128_t DTYPE_t

def interaction(np.ndarray s, DTYPE_t alpha, float kprop):

    cdef float k2 = kprop*kprop
    cdef int i,j
    cdef np.ndarray xi, xj, dx, n, nxn
    cdef float R, kR, kR2
    cdef DTYPE_t A

    cdef int ndipoles = s.shape[0]
    cdef np.ndarray Amat = np.zeros((ndipoles,3, ndipoles, 3), dtype=DTYPE)
    cdef np.ndarray I = np.array([[1,0,0],[0,1,0],[0,0,1]])
    cdef DTYPE_t im = complex(0,1)

    for i in range(ndipoles):
        xi = s[i,:]
        for j in range(ndipoles):
            if i != j:
                xj = s[j,:]
                dx = xi-xj
                R = np.sqrt(dx.dot(dx))
                n = dx/R
                kR = kprop*R
                kR2 = kR*kR
                A = ((1./kR2) - im/kR)
                nxn = np.outer(n, n)
                nxn = (3*A-1)*nxn + (1-A)*I
                nxn *= -alpha*(k2*np.exp(im*kR))/R
            else:
                nxn = I

            Amat[i,:,j,:] = nxn

    return(Amat.reshape((3*ndipoles,3*ndipoles)))

NumPy的真正威力在于以矢量化的方式跨大量元素执行操作,而不是在循环中分散的块中使用该操作。在本例中,您使用两个嵌套循环和一个IF条件语句。我建议扩展中间数组的维度,这将发挥作用,因此可以一次性对所有元素使用相同的操作,而不是循环中的小块数据

对于扩展尺寸,可以使用。因此,遵循这样一个前提的向量化实现如下所示-

def vectorized_interaction(s,alpha,kprop):

    im = complex(0,1)
    I = np.array([[1,0,0],[0,1,0],[0,0,1]])
    k2 = kprop*kprop

    # Vectorized calculations for dx, R, n, kR, A
    sd = s[:,None] - s 
    Rv = np.sqrt((sd**2).sum(2))
    nv = sd/Rv[:,:,None]
    kRv = Rv*kprop
    Av = (1./(kRv*kRv)) - im/kRv

    # Vectorized calculation for: "nxn = np.outer(n, n)"
    nxnv = nv[:,:,:,None]*nv[:,:,None,:]

    # Vectorized calculation for: "(3*A-1)*nxn + (1-A)*I"
    P = (3*Av[:,:,None,None]-1)*nxnv + (1-Av[:,:,None,None])*I

    # Vectorized calculation for: "-alpha*(k2*np.exp(im*kR))/R"    
    multv = -alpha*(k2*np.exp(im*kRv))/Rv

    # Vectorized calculation for: "nxn *= -alpha*(k2*np.exp(im*kR))/R"   
    outv = P*multv[:,:,None,None]


    # Simulate ELSE part of the conditional statement"if i != j:" 
    # with masked setting to I on the last two dimensions
    outv[np.eye((N),dtype=bool)] = I

    return outv.transpose(0,2,1,3).reshape(N*3,-1)
运行时测试和输出验证-

案例1:

案例2:

案例3:


Numpy是一个C库。用BLAS做代数,所以速度很快。我真的不明白cython的内部结构是如何工作的,但是作为一个已经是numpy的C代码,速度的提高体现在任何“不是numpy”的东西上。我假设嵌套循环中有足够多的逐行操作需要直接调用Python解释器,因此这些行可能是相对于Numpy的主要成本,但可能不是。您可以尝试键入Numpy数组,以便编译器知道数组中的类型。不过,我不确定差别会有多大。您可能希望在python代码上运行探查器,以查看实际速度下降的地方。如果大部分时间都花在numpy例程上,那么使用cython就不会有多大收获。我猜(但不是想当然地认为)最大的时差在这个范围内(偶极子)。尝试使用更快的numpy.arange(),看看你离cython有多近。也许可以研究分析/注释你的代码,看看瓶颈在哪里:这非常有效,对我来说也是一个很好的矢量化教育。但有一个问题:警告是用零除以生成的,可能对应于i=j的情况,因为R等于0。有没有办法在不产生警告的情况下完成同样的事情?我意识到数组中受影响的元素在返回之前的最后一行被覆盖,所以我可能不得不强迫自己忽略它们。另一个评论:矢量化版本似乎比原始版本需要更多的内存,因此当我将N提高到9000时,进程似乎至少需要50 GB(我的机器中有16 GB)。也许这是矢量化不可避免的折衷。@GrantPetty是的,在我们最终设置这些值时,您需要忽略这些警告。另外,您是对的,矢量化的折衷是您使用了更多的内存,但这就是为什么矢量化这么好,因为它存储了所有这些元素在内存中,一次完成所有操作。
In [703]: N = 10
     ...: s = np.random.rand(N,3) + complex(0,1)*np.random.rand(N,3)
     ...: alpha = 3j
     ...: kprop = 5.4
     ...: 

In [704]: out_org = interaction(s,alpha,kprop)
     ...: out_vect = vectorized_interaction(s,alpha,kprop)
     ...: print np.allclose(np.real(out_org),np.real(out_vect))
     ...: print np.allclose(np.imag(out_org),np.imag(out_vect))
     ...: 
True
True

In [705]: %timeit interaction(s,alpha,kprop)
100 loops, best of 3: 7.6 ms per loop

In [706]: %timeit vectorized_interaction(s,alpha,kprop)
1000 loops, best of 3: 304 µs per loop
In [707]: N = 100
     ...: s = np.random.rand(N,3) + complex(0,1)*np.random.rand(N,3)
     ...: alpha = 3j
     ...: kprop = 5.4
     ...: 

In [708]: out_org = interaction(s,alpha,kprop)
     ...: out_vect = vectorized_interaction(s,alpha,kprop)
     ...: print np.allclose(np.real(out_org),np.real(out_vect))
     ...: print np.allclose(np.imag(out_org),np.imag(out_vect))
     ...: 
True
True

In [709]: %timeit interaction(s,alpha,kprop)
1 loops, best of 3: 826 ms per loop

In [710]: %timeit vectorized_interaction(s,alpha,kprop)
100 loops, best of 3: 14 ms per loop
In [711]: N = 900
     ...: s = np.random.rand(N,3) + complex(0,1)*np.random.rand(N,3)
     ...: alpha = 3j
     ...: kprop = 5.4
     ...: 

In [712]: out_org = interaction(s,alpha,kprop)
     ...: out_vect = vectorized_interaction(s,alpha,kprop)
     ...: print np.allclose(np.real(out_org),np.real(out_vect))
     ...: print np.allclose(np.imag(out_org),np.imag(out_vect))
     ...: 
True
True

In [713]: %timeit interaction(s,alpha,kprop)
1 loops, best of 3: 1min 7s per loop

In [714]: %timeit vectorized_interaction(s,alpha,kprop)
1 loops, best of 3: 1.59 s per loop