How to use numba to make numpy functions faster


I have heard that numba makes code run faster. I tried it like this:

import numpy as np
from numba import jit

@jit
def abs_diffs_signal(data):
    return np.sum(np.abs(np.diff(data, axis=0)), axis=0)
subjects_data.agg(['min', abs_diffs_signal])
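
It gave me the following warning:

Compilation is falling back to object mode WITH looplifting enabled because Function "abs_diffs_signal" failed type inference due to: non-precise type pyobject

It also made the whole thing slower.

So what should I do to make it faster?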

I call the function above like this:

subjects_data.agg(['min', abs_diffs_signal])

Here subjects_data is the result of a groupby. The full error message:

In [155]: abs_diffs_signal(np.arange(12).reshape(3,4))
<ipython-input-154-3b4e0175a309>:1: NumbaWarning: 
Compilation is falling back to object mode WITH looplifting enabled because Function "abs_diffs_signal" failed type inference due to: No implementation of function Function(<function diff at 0x7fcf048dc488>) found for signature:
 
 >>> diff(array(int64, 2d, C), axis=Literal[int](0))
 
There are 2 candidate implementations:
  - Of which 2 did not match due to:
  Overload in function 'np_diff_impl': File: numba/np/arraymath.py: Line 3374.
    With argument(s): '(array(int64, 2d, C), axis=Literal[int](0))':
   Rejected as the implementation raised a specific error:
     TypeError: np_diff_impl() got an unexpected keyword argument 'axis'
  raised from /usr/local/lib/python3.6/dist-packages/numba/core/typing/templates.py:675

During: resolving callee type: Function(<function diff at 0x7fcf048dc488>)
During: typing of call at <ipython-input-154-3b4e0175a309> (3)


File "<ipython-input-154-3b4e0175a309>", line 3:
def abs_diffs_signal(data):
    return np.sum(np.abs(np.diff(data,axis=0)),axis=0)
    ^

  @numba.jit
/usr/local/lib/python3.6/dist-packages/numba/core/object_mode_passes.py:178: NumbaWarning: Function "abs_diffs_signal" was compiled in object mode without forceobj=True.

File "<ipython-input-154-3b4e0175a309>", line 2:
@numba.jit
def abs_diffs_signal(data):
^

  state.func_ir.loc))
/usr/local/lib/python3.6/dist-packages/numba/core/object_mode_passes.py:188: NumbaDeprecationWarning: 
Fall-back from the nopython compilation path to the object mode compilation path has been detected, this is deprecated behaviour.

For more information visit http://numba.pydata.org/numba-doc/latest/reference/deprecation.html#deprecation-of-object-mode-fall-back-behaviour-when-using-jit

File "<ipython-input-154-3b4e0175a309>", line 2:
@numba.jit
def abs_diffs_signal(data):
^

  state.func_ir.loc))
Out[155]: array([8, 8, 8, 8])
I'm not a numba expert, but its np.diff implementation does not seem to accept the axis parameter. Is that documented?

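According to the numba docs:

numpy.diff() (only the 2 first arguments)

The NumPy signature is:

def diff(a, n=1, axis=-1, prepend=np._NoValue, append=np._NoValue):

So a and n are supported, but axis is not.

diff is a fairly simple function. For 1d it is just arr[1:] - arr[:-1], so you should be able to replace the call with that expression written out.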
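As a minimal sketch of that replacement (the helper name `abs_diffs_signal_sliced` is hypothetical, not from the original posts), the difference along axis 0 can be written with plain slicing, which avoids the unsupported `axis` keyword of `np.diff`. The check here runs in plain NumPy. Decorating the function with `numba.njit` would be the next step, assuming the remaining calls are supported by your numba version.

```python
import numpy as np

def abs_diffs_signal_sliced(data):
    # data[1:] - data[:-1] is the row-to-row difference,
    # i.e. the same values as np.diff(data, axis=0)
    diffs = data[1:] - data[:-1]
    return np.sum(np.abs(diffs), axis=0)

data = np.arange(12).reshape(3, 4)

# The slicing expression matches np.diff along axis 0
assert np.array_equal(data[1:] - data[:-1], np.diff(data, axis=0))

print(abs_diffs_signal_sliced(data))  # [8 8 8 8]
```

The printed values agree with the `Out[155]` result shown in the error log above.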


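In general, you cannot just throw a jit around a numpy function and expect it to work, let alone work faster. Sometimes it does, but in many cases you need to pay attention to both the error messages and the docs. It is not a no-brainer tool.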

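If you want to beat numpy functions that are already vectorized, it is best to simplify the code a bit (BLAS calls on large arrays are an exception). The numpy functions themselves are all quite fast, but in this example numpy allocates temporary arrays, which is expensive and, as the array size grows, also causes cache misses.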
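To make the temporary arrays visible, here is a small sketch that splits the one-line NumPy expression into its steps and adds up the memory of the intermediates. The array size and the variable names are chosen only for illustration.

```python
import numpy as np

data = np.random.rand(1000, 20)

diffs = np.diff(data, axis=0)    # first temporary, shape (999, 20)
abs_diffs = np.abs(diffs)        # second temporary, same shape
result = abs_diffs.sum(axis=0)   # final result, shape (20,)

# Bytes held by the two intermediates versus the final result
temp_bytes = diffs.nbytes + abs_diffs.nbytes
print(temp_bytes, result.nbytes)  # 319680 160
```

Each intermediate is almost as large as the input, while the result only needs one value per column. An explicit loop that accumulates directly into the result does not need these intermediates.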

Example with 2D input

import numpy as np
import numba as nb

def abs_diffs_signal_np(data):
    return np.sum(np.abs(np.diff(data,axis=0)),axis=0)

@nb.njit()
def abs_diffs_signal(data):
    res=np.zeros(data.shape[1],dtype=data.dtype)
    for i in range(data.shape[0]-1):
        for j in range(data.shape[1]):
            res[j]+=np.abs(data[i+1,j]-data[i,j])
    return res

Timings

data=np.random.rand(1_000_000,20)
np.allclose(abs_diffs_signal_np(data),abs_diffs_signal(data))
#True
%timeit abs_diffs_signal_np(data)
#402 ms ± 1.06 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit abs_diffs_signal(data)
#12.2 ms ± 20.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

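If your code is already written as vectorized operations, numba will not help you. Maybe you can write a loop that does the same thing, which numba can compile to run faster?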
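A minimal sketch of such a loop for a single 1D signal follows. The function name `abs_diff_sum_1d` and the sample values are hypothetical, and the `try`/`except` fallback is only there so the snippet also runs (uncompiled) where numba is not installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Illustration-only fallback: run as plain Python if numba is missing
    def njit(func):
        return func

@njit
def abs_diff_sum_1d(signal):
    # Accumulate |signal[i+1] - signal[i]| without any temporary arrays
    total = 0.0
    for i in range(signal.shape[0] - 1):
        total += abs(signal[i + 1] - signal[i])
    return total

signal = np.array([1.0, 4.0, 2.0, 2.5])
print(abs_diff_sum_1d(signal))  # 5.5
```

The function expects a plain NumPy array, so when the data comes from pandas it is probably safest to pass the underlying array (for example via `.to_numpy()`) rather than the pandas object itself.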