Python 使用带有dicts和tuple输入的类

Python 使用带有dicts和tuple输入的类,python,dictionary,jit,numba,Python,Dictionary,Jit,Numba,我正在优化一些代码,这些代码主要包含在一个python类中。它对python对象的操作非常少,所以我认为使用Nuba会是一个很好的匹配,但在创建对象的过程中,我有大量需要的参数,而且我不认为我完全理解Nuba相对较新的dict支持()。我拥有的参数都是单浮点数或整数,它们被传递到对象中,存储,然后在整个代码运行过程中使用,如下所示: import numpy as np from numba import jitclass, float64 spec = [ ('p', dict),

我正在优化一些代码,这些代码主要包含在一个python类中。它对python对象的操作非常少,所以我认为使用Nuba会是一个很好的匹配,但在创建对象的过程中,我有大量需要的参数,而且我不认为我完全理解Nuba相对较新的dict支持()。我拥有的参数都是单浮点数或整数,它们被传递到对象中,存储,然后在整个代码运行过程中使用,如下所示:

import numpy as np
from numba import jitclass, float64

spec = [
    ('p', dict),
    ('shape', tuple),               # the shape of the array
    ('array', float64[:,:]),          # an array field
]

params_default = {
    par_1 = 1,
    par_2 = 0.5
    }

@jitclass(spec)
class myObj:
    def __init__(self,params = params_default,shape = (100,100)):
        self.p = params
        self.shape = shape
        self.array = self.p['par_2']*np.ones(shape)

    def inc_arr(self):
        self.array += self.p['par_1']*np.ones(shape)
我想我不明白Numba需要什么。如果我想使用nopython模式使用Numba来优化它,我需要向jitclass装饰器传递一个规范吗?如何定义字典的规范?我也需要声明形状元组吗?我已经查看了我在jitclass decorator上找到的文档以及dict numba文档,我不知道该怎么做。运行上述代码时,出现以下错误:

TypeError: spec values should be Numba type instances, got <class 'dict'>
TypeError:规范值应为Numba类型实例,已获取
我是否需要在规范中包含dict元素?文档中不清楚正确的语法是什么


或者,有没有办法让Numba推断输入类型?

spec
需要由特定于Numba的类型组成,而不是python类型! 因此规范中的
tuple
dict
必须是类型的numba类型(并且仅允许使用同质dict)

因此,要么在jitted函数中指定
params_default
dict,如图所示,要么显式键入numba dict

对于这种情况,我将采用后一种方法:

import numpy as np
from numba import jitclass, float64

# Explicitly define the types of the key and value:
params_default = nb.typed.Dict.empty(
    key_type=nb.typeof('par_1'),
    value_type=nb.typeof(0.5)
)

# assign your default values
params_default['par_1'] = 1.  # Same type required, thus setting to float
params_default['par_2'] = .5

spec = [
    ('p', nb.typeof(params_default)),
    ('shape', nb.typeof((100, 100))),               # the shape of the array
    ('array', float64[:, :]),          # an array field
]

@jitclass(spec)
class myObj:
    def __init__(self, params=params_default, shape=(100, 100)):
        self.p = params
        self.shape = shape
        self.array = self.p['par_2'] * np.ones(shape)

    def inc_arr(self):
        self.array += self.p['par_1'] * np.ones(shape)
正如已经指出的那样:事实上,dict是同源类型的。因此,所有键/值必须是相同的类型。因此,在同一个dict中存储
int
float
将不起作用