Warning: file_get_contents(/data/phpspider/zhask/data//catemap/0/performance/5.json): failed to open stream: No such file or directory in /data/phpspider/zhask/libs/function.php on line 167

Warning: Invalid argument supplied for foreach() in /data/phpspider/zhask/libs/tag.function.php on line 1116

Notice: Undefined index: in /data/phpspider/zhask/libs/function.php on line 180

Warning: array_chunk() expects parameter 1 to be array, null given in /data/phpspider/zhask/libs/function.php on line 181
Performance cython跑得比numpy慢,用于距离计算_Performance_Numpy_Distance_Cython - Fatal编程技术网

Performance cython跑得比numpy慢,用于距离计算

Performance cython跑得比numpy慢,用于距离计算,performance,numpy,distance,cython,Performance,Numpy,Distance,Cython,我正在努力学习赛昂;然而,我一定是做错了什么。这段测试代码的运行速度比我的矢量化numpy版本慢50倍左右。谁能告诉我为什么我的cython比python慢?谢谢 代码计算R^3中的点loc和R^3中的点阵列points之间的距离 import numpy as np cimport numpy as np import cython cimport cython DTYPE = np.float64 ctypedef np.float64_t DTYPE_t @cython.boundsch

我正在努力学习赛昂;然而,我一定是做错了什么。这段测试代码的运行速度比我的矢量化numpy版本慢50倍左右。谁能告诉我为什么我的cython比python慢?谢谢

代码计算R^3中的点loc和R^3中的点阵列points之间的距离

import numpy as np
cimport numpy as np
import cython
cimport cython

DTYPE = np.float64
ctypedef np.float64_t DTYPE_t
@cython.boundscheck(False) # turn of bounds-checking for entire function
@cython.wraparound(False)
@cython.nonecheck(False)
def distMeasureCython(np.ndarray[DTYPE_t, ndim=2] points, np.ndarray[DTYPE_t, ndim=1] loc):
    cdef unsigned int i
    cdef unsigned int L = points.shape[0]
    cdef np.ndarray[DTYPE_t, ndim=1] d = np.zeros(L)
    for i in xrange(0,L):
        d[i] = np.sqrt((points[i,0] - loc[0])**2 + (points[i,1] - loc[1])**2 + (points[i,2]  - loc[2])**2)
    return d
这是与之比较的numpy代码

from numpy import *
N = 1e6
points = random.uniform(0,1,(N,3))
loc = random.uniform(0,1,(3))

def distMeasureNumpy(points,loc):
    d = points - loc
    d = sqrt(sum(d*d,axis=1))
    return d

numpy/python版本大约需要44毫秒,cython版本大约需要2秒钟。我正在MacOSX上运行python 2.7。我正在使用ipython的%timeit命令对这两个函数计时

np.sqrt
的调用是一个Python函数调用,它正在破坏您的性能您正在计算标量浮点值的平方根,因此您应该使用C数学库中的
sqrt
函数。以下是您的代码的修改版本:

import numpy as np
cimport numpy as np
import cython
cimport cython

from libc.math cimport sqrt

DTYPE = np.float64
ctypedef np.float64_t DTYPE_t
@cython.boundscheck(False) # turn of bounds-checking for entire function
@cython.wraparound(False)
@cython.nonecheck(False)
def distMeasureCython(np.ndarray[DTYPE_t, ndim=2] points,
                      np.ndarray[DTYPE_t, ndim=1] loc):
    cdef unsigned int i
    cdef unsigned int L = points.shape[0]
    cdef np.ndarray[DTYPE_t, ndim=1] d = np.zeros(L)
    for i in xrange(0,L):
        d[i] = sqrt((points[i,0] - loc[0])**2 +
                    (points[i,1] - loc[1])**2 +
                    (points[i,2] - loc[2])**2)
    return d
下面演示了性能改进。您的原始代码位于模块
检查速度\u original
中,修改后的版本位于
检查速度
中:

In [11]: import check_speed_original

In [12]: import check_speed
设置测试数据:

In [13]: N = 10**6

In [14]: points = random.uniform(0,1,(N,3))

In [15]: loc = random.uniform(0,1,(3,))
原始版本在我的计算机上需要1.26秒:

In [16]: %timeit check_speed_original.distMeasureCython(points, loc)
1 loops, best of 3: 1.26 s per loop
修改后的版本需要4.47毫秒:

如果有人担心结果可能不同:

In [18]: d1 = check_speed.distMeasureCython(points, loc)

In [19]: d2 = check_speed_original.distMeasureCython(points, loc)

In [20]: np.all(d1 == d2)
Out[20]: True

如前所述,这是代码中的numpy.sqrt调用。但是,我认为不需要使用
cdefextern
,因为Cython已经提供了这些基本的C/C++库。(见附件)。所以你可以这样导入它:

    from libc.math cimport sqrt

只是为了消除开销。

我没有发现cython版本有任何明显的错误(我很惊讶它的速度如此之慢)。然而,你不会用cython打败一个正确矢量化的numpy表达式。Cython对于无法矢量化的操作是最好的(而且通常相当好)。另外,您可以使用
d=np.hypot(*d.T)
稍微加快numpy版本的速度。您是否运行了
cython-a your_code.pyx
并查看了
your_code.html
?这是一种检查cython生成的C代码的便捷方法,可以发现有多少代码已经翻译成C,还有多少代码仍然在Python级别工作。它很有效!非常感谢。我也得到了你上面提到的运行时间。谢谢你给我关于html的提示。这是我的第一个问题,既然问题解决了,我是否应该关闭它?谢谢你的帮助。
    from libc.math cimport sqrt