Python: Correctly annotating a numba function with jit
I started with this code to compute a simple matrix multiplication. On my machine it runs in about 7.85 s according to %timeit. To speed it up I tried Cython, which cut the time to 0.4 s. I also wanted to try the numba jit compiler to see whether I could get a similar speedup with less effort. But adding the @jit annotation seems to give exactly the same timing (~7.8 s). I know it doesn't know the types in the calculate_z_numpy call, but I'm not sure what I can do to force it. Any ideas?
from numba import jit
import numpy as np
@jit('f8(c8[:],c8[:],uint)')
def calculate_z_numpy(q, z, maxiter):
    """use vector operations to update all zs and qs to create new output array"""
    output = np.resize(np.array(0, dtype=np.int32), q.shape)
    for iteration in range(maxiter):
        z = z*z + q
        done = np.greater(abs(z), 2.0)
        q = np.where(done, 0+0j, q)
        z = np.where(done, 0+0j, z)
        output = np.where(done, iteration, output)
    return output

def calc_test():
    w = h = 1000
    maxiter = 1000
    # make a list of x and y values which will represent q
    # xx and yy are the co-ordinates, for the default configuration they'll look like:
    # if we have a 1000x1000 plot
    # xx = [-2.13, -2.1242,-2.1184000000000003, ..., 0.7526000000000064, 0.7584000000000064, 0.7642000000000064]
    # yy = [1.3, 1.2948, 1.2895999999999999, ..., -1.2844000000000058, -1.2896000000000059, -1.294800000000006]
    x1, x2, y1, y2 = -2.13, 0.77, -1.3, 1.3
    x_step = (float(x2 - x1) / float(w)) * 2
    y_step = (float(y1 - y2) / float(h)) * 2
    # np.complex128 replaces the np.complex alias, which newer NumPy releases no longer provide
    y = np.arange(y2,y1-y_step,y_step,dtype=np.complex128)
    x = np.arange(x1,x2,x_step)
    q1 = np.empty(y.shape[0],dtype=np.complex128)
    q1.real = x
    q1.imag = y
    # Transpose y
    x_y_square_matrix = x+y[:, np.newaxis] # it is np.complex128
    # convert square matrix to a flatted vector using ravel
    q2 = np.ravel(x_y_square_matrix)
    # create z as a 0+0j array of the same length as q
    # note that it defaults to reals (float64) unless told otherwise
    z = np.zeros(q2.shape, np.complex128)
    output = calculate_z_numpy(q2, z, maxiter)
    print(output)

calc_test()
With some help from someone else, I figured out how to do this.
@jit('i4[:](c16[:],c16[:],i4,i4[:])',nopython=True)
def calculate_z_numpy(q, z, maxiter,output):
    """use explicit loops to update all zs and qs, writing into the output array"""
    for iteration in range(maxiter):
        for i in range(len(z)):
            z[i] = z[i]*z[i] + q[i]  # same update as z = z*z + q in the question
            if abs(z[i]) > 2.0:  # complex values cannot be ordered, so compare the magnitude
                output[i] = iteration
                z[i] = 0+0j
                q[i] = 0+0j
    return output
What I learned is to use numpy data structures as the inputs (so they can be typed), but to loop internally in a C-like style.
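It runs in 402 ms, slightly faster than the Cython code at 0.45 s, so for the minimal effort of rewriting the loops explicitly, we have a Python version that is (just) faster than the C one.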
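The answer does not show the calling side, so here is a minimal sketch of how the loop-based kernel might be driven. It rests on a few assumptions: it uses numba's lazy `@njit` decorator rather than the explicit signature string above, the names `calculate_z_loops` and `escape_iterations` are hypothetical, and the `try`/`except` fallback exists only so the sketch still runs (uncompiled) where numba is not installed. The kernel needs a preallocated `int32` output array and overwrites `q` and `z` in place, so the wrapper passes in copies.

```python
import numpy as np

try:
    from numba import njit  # compiles in nopython mode, types inferred on first call
except ImportError:
    # Assumption for this sketch only: a no-op stand-in so the code
    # still runs (slowly, as plain Python) when numba is unavailable.
    def njit(func):
        return func


@njit
def calculate_z_loops(q, z, maxiter, output):
    """Explicit-loop update; mutates q, z and output in place."""
    for iteration in range(maxiter):
        for i in range(len(z)):
            z[i] = z[i] * z[i] + q[i]
            if abs(z[i]) > 2.0:
                output[i] = iteration
                z[i] = 0j
                q[i] = 0j
    return output


def escape_iterations(q, maxiter):
    """Hypothetical wrapper: build the typed arrays the kernel expects."""
    q = np.array(q, dtype=np.complex128)        # copy, because the kernel overwrites q
    z = np.zeros(q.shape, dtype=np.complex128)  # z starts at 0+0j for every point
    output = np.zeros(q.shape, dtype=np.int32)  # preallocated result buffer
    return calculate_z_loops(q, z, maxiter, output)
```

With lazy compilation the first call includes compile time, so when timing it is worth calling the function once on a small input before measuring.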