Python 使用Numba加速以下代码

Python 使用Numba加速以下代码,python,numpy,numba,Python,Numpy,Numba,我正在尝试使用Numba来加速一段代码。代码很简单,基本上是一个在numpy数组上进行简单计算的循环 import numpy as np import time from numba import jit, double def MinimizeSquareDiffBudget(x, budget): if (budget > np.sum(x)): return x n = np.size(x,0) j = 1 i = 0 y

我正在尝试使用Numba来加速一段代码。代码很简单,基本上是一个在numpy数组上进行简单计算的循环

import numpy as np
import time
from numba import jit, double

def MinimizeSquareDiffBudget(x, budget):
    if (budget > np.sum(x)):
        return x
    n = np.size(x,0)
    j = 1
    i = 0
    y = np.zeros((n, 1))
    while (budget > 0):
        while (x[i] == x[j]) and (j < n-1):
            j += 1
        i = j - 1
        if (np.std(x)<1e-10):
            to_give = budget/n
            y += to_give
            x= x- to_give
            break
        to_give = min(budget, (x[0] - x[j])*j)
        y[0:j] += to_give/j
        x[0:j]=x[0:j]-to_give/j
        budget = budget - to_give
        j = 1
    return y
然而,时间大致相同,而我预计Numba会快得多

测试代码:

budget = 335.0

x = np.random.uniform(0,1,(1000,1))
x.sort(axis=0)
x = x[::-1]
t = time.process_time()
y = MinimizeSquareDiffBudget(x, budget)
print(time.process_time()-t)

x = np.random.uniform(0,1,(1000,1))
x.sort(axis=0)
x = x[::-1]

t = time.process_time()
y = fastMinimizeSquareDiffBudget(x, budget)
print(time.process_time()-t)
直接实现需要0.28秒,使用Numba优化代码需要0.45秒。用C编写的相同代码所需时间少于0.001秒


有什么想法吗?

当您只对jit函数的一次执行计时时,您会看到运行时以及Numba对代码进行jit所需的时间。如果您再次运行代码,您将看到实际的加速,因为Numba使用编译函数的内存缓存,所以您只需为每个参数类型支付一次编译时间

在我使用python 3.6和numba 0.31.0的机器上,纯python函数需要0.32秒。第一次调用fastMinimizeSquareDiffBudget需要0.57秒,但第二次需要0.31秒

现在,您没有看到巨大的速度提升的原因是,您有一个函数,Numba无法在nopython模式下编译,因此它会返回到慢得多的对象模式。如果将nopython=True传递给jit方法,您将能够看到它无法编译的地方。我看到的两个问题是,您应该使用x.shape[0]而不是np.sizex,0,并且不能以您现在的方式使用min

budget = 335.0

x = np.random.uniform(0,1,(1000,1))
x.sort(axis=0)
x = x[::-1]
t = time.process_time()
y = MinimizeSquareDiffBudget(x, budget)
print(time.process_time()-t)

x = np.random.uniform(0,1,(1000,1))
x.sort(axis=0)
x = x[::-1]

t = time.process_time()
y = fastMinimizeSquareDiffBudget(x, budget)
print(time.process_time()-t)