Python: optimizing a Numba and NumPy function
I'm trying to make this code run faster, but I can't find any more tricks to speed it up.
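I get a runtime of about 3 microseconds per call. The problem is that I call this function several million times, so the whole process ends up taking a long time. I have the same implementation in Java (just basic for loops), and there even computations on large training data (this is an ANN) are practically instantaneous. Is there a way to speed this up? I'm running Python 2.7, numba 0.43.1 and numpy 1.16.3 on Windows 10.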
```python
import numba as nb
import numpy as np

x = True
expected = 0.5
eligibility = np.array([0.1, 0.1, 0.1])
positive_weight = np.array([0.2, 0.2, 0.2])
total_sq_grad_positive = np.array([0.1, 0.1, 0.1])
learning_rate = 1

@nb.njit(fastmath=True, cache=True)
def update_weight_from_post_post_jit(x, expected, eligibility, positive_weight, total_sq_grad_positive, learning_rate):
    if x:
        g = np.multiply(eligibility, (1 - expected))
    else:
        g = np.negative(np.multiply(eligibility, expected))
    gg = np.multiply(g, g)
    total_sq_grad_positive = np.add(total_sq_grad_positive, gg)
    #total_sq_grad_positive = np.where(divide_by_zero,total_sq_grad_positive, tsgp_temp)
    temp = np.multiply(learning_rate, g)
    temp2 = np.sqrt(total_sq_grad_positive)
    #temp2 = np.where(temp2 == 0,1,temp2 )
    temp2[temp2 == 0] = 1
    temp = np.divide(temp, temp2)
    positive_weight = np.add(positive_weight, temp)
    return [positive_weight, total_sq_grad_positive]
```
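For reference, the per-element update this function computes can be written out as follows. The symbols are introduced here only for readability (an assumption, not the asker's notation): $e_i$ = `eligibility[i]`, $y$ = `expected`, $G_i$ = `total_sq_grad_positive[i]`, $w_i$ = `positive_weight[i]`, $\eta$ = `learning_rate`. It has the shape of an AdaGrad-style step, and each element $i$ is updated independently of the others, which is why a plain scalar loop can replace all the whole-array operations.

$$
g_i = \begin{cases} e_i\,(1-y) & \text{if } x \\ -\,e_i\,y & \text{otherwise} \end{cases}
\qquad
G_i \leftarrow G_i + g_i^2
\qquad
w_i \leftarrow w_i + \frac{\eta\, g_i}{\sqrt{G_i}} \quad (\sqrt{G_i} \text{ replaced by } 1 \text{ when it is } 0)
$$

With the sample inputs ($x$ = True, $y = 0.5$, $e_i = 0.1$, $G_i = 0.1$, $w_i = 0.2$, $\eta = 1$): $g_i = 0.05$, $G_i = 0.1025$, $\sqrt{G_i} \approx 0.320156$, so $w_i \approx 0.2 + 0.156174 = 0.356174$.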
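Edit: It looks like @max9111 was right: the unnecessary temporary arrays are the source of the overhead.

With the function's current semantics, there seem to be two temporary arrays that can't be avoided: the return values `[positive_weight, total_sq_grad_positive]`. However, it occurred to me that you may be planning to use this function to update those two input arrays. If so, doing everything in place gives the biggest speedup, like this: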
```python
import numba as nb
import numpy as np

x = True
expected = 0.5
eligibility = np.array([0.1, 0.1, 0.1])
positive_weight = np.array([0.2, 0.2, 0.2])
total_sq_grad_positive = np.array([0.1, 0.1, 0.1])
learning_rate = 1

@nb.njit(fastmath=True, cache=True)
def update_weight_from_post_post_jit(x, expected, eligibility, positive_weight, total_sq_grad_positive, learning_rate):
    # Update both input arrays in place, one element at a time, with no temporary arrays
    for i in range(eligibility.shape[0]):
        if x:
            g = eligibility[i] * (1 - expected)
        else:
            g = -(eligibility[i] * expected)
        gg = g * g
        total_sq_grad_positive[i] = total_sq_grad_positive[i] + gg
        temp = learning_rate * g
        temp2 = np.sqrt(total_sq_grad_positive[i])
        if temp2 == 0:
            temp2 = 1.0
        temp = temp / temp2
        positive_weight[i] = positive_weight[i] + temp
```
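Because the in-place loop no longer returns anything, it is worth checking that it still produces the same numbers as the vectorized version. The sketch below does that for both branches of `x`. The function names `update_reference` and `update_in_place` are hypothetical, and the `try`/`except` shim is an assumption added only so the sketch also runs (slowly) where Numba is not installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # assumption: no-op stand-in so the sketch still runs without Numba
    def njit(*args, **kwargs):
        if len(args) == 1 and callable(args[0]):
            return args[0]  # used as bare @njit
        return lambda f: f  # used as @njit(...)


def update_reference(x, expected, eligibility, positive_weight, total_sq_grad_positive, learning_rate):
    """Vectorized NumPy version with the original semantics: returns new arrays."""
    g = eligibility * (1 - expected) if x else -(eligibility * expected)
    total = total_sq_grad_positive + g * g
    denom = np.sqrt(total)
    denom[denom == 0] = 1.0  # same zero guard as the original
    return positive_weight + learning_rate * g / denom, total


@njit
def update_in_place(x, expected, eligibility, positive_weight, total_sq_grad_positive, learning_rate):
    """Scalar loop: overwrites positive_weight and total_sq_grad_positive, allocates nothing."""
    for i in range(eligibility.shape[0]):
        g = eligibility[i] * (1 - expected) if x else -(eligibility[i] * expected)
        total_sq_grad_positive[i] += g * g
        denom = np.sqrt(total_sq_grad_positive[i])
        if denom == 0:
            denom = 1.0
        positive_weight[i] += learning_rate * g / denom


eligibility = np.array([0.1, 0.1, 0.1])
for x in (True, False):
    w = np.array([0.2, 0.2, 0.2])
    G = np.array([0.1, 0.1, 0.1])
    w_ref, G_ref = update_reference(x, 0.5, eligibility, w, G, 1)  # does not touch w, G
    update_in_place(x, 0.5, eligibility, w, G, 1)  # mutates w and G
    print(x, np.allclose(w, w_ref), np.allclose(G, G_ref))
```

The reference version is called first because it leaves `w` and `G` untouched, so both versions start from the same state.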
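If you don't want to update the input arrays, you can start the function with `positive_weight = positive_weight.copy()` and `total_sq_grad_positive = total_sq_grad_positive.copy()`, and then return them as in the original code. That is not nearly as fast, but it is still faster than the original.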
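A minimal sketch of that copy-then-loop variant, assuming the caller wants fresh arrays back and its own arrays left unchanged. `update_weight_copy` is a hypothetical name; it returns a tuple rather than the original list (a choice made here, not taken from the answer), and the import shim is again only an assumption so the sketch runs without Numba.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # assumption: no-op stand-in so the sketch still runs without Numba
    def njit(*args, **kwargs):
        if len(args) == 1 and callable(args[0]):
            return args[0]
        return lambda f: f


@njit
def update_weight_copy(x, expected, eligibility, positive_weight, total_sq_grad_positive, learning_rate):
    # Copy once up front so the caller's arrays stay untouched; the rest is the in-place loop
    positive_weight = positive_weight.copy()
    total_sq_grad_positive = total_sq_grad_positive.copy()
    for i in range(eligibility.shape[0]):
        g = eligibility[i] * (1 - expected) if x else -(eligibility[i] * expected)
        total_sq_grad_positive[i] += g * g
        denom = np.sqrt(total_sq_grad_positive[i])
        if denom == 0:
            denom = 1.0
        positive_weight[i] += learning_rate * g / denom
    return positive_weight, total_sq_grad_positive


eligibility = np.array([0.1, 0.1, 0.1])
w = np.array([0.2, 0.2, 0.2])
G = np.array([0.1, 0.1, 0.1])
new_w, new_G = update_weight_copy(True, 0.5, eligibility, w, G, 1)
# w and G still hold their old values; new_w and new_G hold the updated ones
```

The two copies are the only allocations per call, versus the several temporaries created by the chain of `np.multiply`/`np.add`/`np.sqrt` calls in the original.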
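I'm not sure this can be optimized to the point of being "instantaneous"; I'm a bit surprised Java manages that, since this looks like a fairly complex function to me, with time-consuming operations such as `sqrt`.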
But have you applied `nb.jit` to the function that calls this one? Like this:
```python
import numba as nb
import numpy as np

x = True
expected = 0.5
eligibility = np.array([0.1, 0.1, 0.1])
positive_weight = np.array([0.2, 0.2, 0.2])
total_sq_grad_positive = np.array([0.1, 0.1, 0.1])
learning_rate = 1

@nb.njit(fastmath=True, cache=True)
def update_weight_from_post_post_jit(x, expected, eligibility, positive_weight, total_sq_grad_positive, learning_rate):
    for i in range(eligibility.shape[0]):
        if x:
            g = eligibility[i] * (1 - expected)
        else:
            g = -(eligibility[i] * expected)
        gg = g * g
        total_sq_grad_positive[i] = total_sq_grad_positive[i] + gg
        temp = learning_rate * g
        temp2 = np.sqrt(total_sq_grad_positive[i])
        if temp2 == 0:
            temp2 = 1.0
        temp = temp / temp2
        positive_weight[i] = positive_weight[i] + temp

@nb.njit
def test(n, x, expected, eligibility, positive_weight, total_sq_grad_positive, learning_rate):
    # The calls happen inside compiled code, so there is no Python call overhead per iteration.
    # The arrays are passed as arguments because Numba freezes global arrays as constants.
    for _ in range(n):
        update_weight_from_post_post_jit(x, expected, eligibility, positive_weight, total_sq_grad_positive, learning_rate)

test(1000000, x, expected, eligibility, positive_weight, total_sq_grad_positive, learning_rate)
```
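On my computer, this cut the runtime in half, which makes sense because the overhead of a Python function call is very high.

Hi, thanks for your reply. I do in fact call this function from another jit function; however, I'm working with classes here, and I basically pulled this method out of the class so that I could do this, because I can't do it from inside the class. Is there perhaps a better way to avoid this?

Just a small remark on the compiler flags: usually (though not in this case, with only 3 divisions) the division-by-zero check costs performance and isn't necessary. You can disable it with `error_model="numpy"`.

I implemented the changes above and I can definitely see a speedup, thanks! Building on this, is there a way to skip the intermediate step, i.e. call the function outside the class and have the jit class return the temporary matrices? Basically updating the `self` variables, but from outside the class? Thanks

Omg, I just refreshed my knowledge of Python: as long as the arguments are objects (e.g. NumPy arrays), they are always passed by reference!

"I have the same Java implementation (only basic for loops)": do the same with the Python code. Every vectorized operation is translated into a for loop with an unnecessary temporary array. Since Numba (Python -> LLVM IR -> LLVM backend) is very similar to Clang (C -> LLVM IR -> LLVM backend), write your code as you would in C.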
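A minimal sketch of the pattern these comments converge on, under stated assumptions: the arrays live as NumPy attributes on a plain Python class (`Neuron` and `adagrad_step` are hypothetical names, not the asker's code), and the jitted kernel is a module-level function rather than a method. Python passes references to objects, so the kernel's in-place writes land directly in `self`'s arrays and nothing has to be returned; only rebinding a parameter name inside the function would be invisible to the caller. `error_model="numpy"` is the Numba option mentioned above: division by zero then yields inf/nan as in NumPy instead of raising `ZeroDivisionError`, so the check can be skipped.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # assumption: no-op stand-in so the sketch still runs without Numba
    def njit(*args, **kwargs):
        if len(args) == 1 and callable(args[0]):
            return args[0]
        return lambda f: f


@njit(error_model="numpy")
def adagrad_step(x, expected, eligibility, weight, sq_grad, learning_rate):
    """Module-level kernel; writes into weight and sq_grad in place."""
    for i in range(eligibility.shape[0]):
        g = eligibility[i] * (1 - expected) if x else -(eligibility[i] * expected)
        sq_grad[i] += g * g
        denom = np.sqrt(sq_grad[i])
        if denom == 0:
            denom = 1.0
        weight[i] += learning_rate * g / denom


class Neuron:
    """Hypothetical plain-Python class that owns the arrays."""

    def __init__(self, n, learning_rate=1.0):
        self.eligibility = np.full(n, 0.1)
        self.positive_weight = np.full(n, 0.2)
        self.total_sq_grad_positive = np.full(n, 0.1)
        self.learning_rate = learning_rate

    def update(self, x, expected):
        # The kernel gets references to self's arrays, so its writes update them directly
        adagrad_step(x, expected, self.eligibility, self.positive_weight,
                     self.total_sq_grad_positive, self.learning_rate)


neuron = Neuron(3)
neuron.update(True, 0.5)
print(neuron.positive_weight)
```

Keeping the hot loop in a free function means the class itself needs no Numba support; when the per-call cost still matters, the loop that calls `update` millions of times would be the next thing to move into compiled code.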