Python 如何编译numba jit';具有可变输入类型的ed函数?

Python 如何编译numba jit';具有可变输入类型的ed函数?,python,random,signature,optional-parameters,numba,Python,Random,Signature,Optional Parameters,Numba,假设我有一个函数,可以接受int或None类型作为输入参数 import numba as nb import numpy as np jitkw = {"nopython": True, "nogil": True, "error_model": "numpy", "fastmath": True} @nb.jit("f8(i8)", **jitkw) def get_random(seed=None): np.random.seed(None) out = np.ran

假设我有一个函数,可以接受
int
None
类型作为输入参数

import numba as nb
import numpy as np

jitkw = {"nopython": True, "nogil": True, "error_model": "numpy", "fastmath": True}


@nb.jit("f8(i8)", **jitkw)
def get_random(seed=None):
    np.random.seed(None)
    out = np.random.normal()
    return out
我希望函数只返回一个正态分布的随机数。如果我想要可复制的结果,seed应该是
int

get_random(42)
>>> 0.4967141530112327
get_random(42)
>>> 0.4967141530112327
get_random(42)
>>> 0.4967141530112327
如果我想要随机数,
seed
应该保留为
None
。但是,如果我不传递参数(因此seed默认为
None
)或显式传递
seed=None
,则numba会引发
TypeError

get_random()
>>> TypeError: No matching definition for argument type(s) omitted(default=None)
get_random(None)
>>> TypeError: No matching definition for argument type(s) omitted(default=None)
在这种情况下,我如何编写函数,仍然声明签名并使用
nopython
模式


我的numba版本是0.43.1

第一个问题是nopython模式下的numba只接受(从0.43.1版开始)

因此,很遗憾,您无法传入
None


第二个问题是(据我所知)没有一个“单一”签名告诉numba如何处理缺少的值,但是您可以使用两个签名(是的,非常详细):

只是简单解释一下signaure的两个部分:

  • nb.types.float64(nb.types.misc.impled(None))
    告诉numba如果省略了参数,则使用
    None
    作为默认类型
  • nb.types.float64(nb.types.int64)
    是需要整数的签名

就我个人而言,我不会指定签名,而只是让numba算出。在numba中,显式签名很少值得使用,而且更常见的情况是,它们会导致代码速度变慢、灵活性变差。

使用numba,如果您包含了您的numba版本,它对将来的参考和回答者总是有帮助的。这是因为支持的功能在不同版本之间可能会有很大的差异。我在定义numpy数组签名时经历了巨大的加速:例如nb.float64[:]或nb.int64[:,:]等,但对于正则变量来说没有太多。@MottTheTuple这很有趣,因为通常情况下,除了第一次调用之外,numba中的类型不会得到加速,因为如果不定义类型,函数将使用给定的参数编译,而使用给定的类型,函数将在定义时编译。签名的目的本质上是“对编译器选择的类型进行细粒度控制”(来自)。如果你有一个例子,其中的函数速度是真的非常不同,我会非常感兴趣的一个例子(例如一个要点)。我看到这是在我的线性回归函数。我发送一些极值,然后计算斜率、截距和残差,所有这些都使用较小的numpy函数:sum、mean、max等。我删除残差最低的行。这都是在while语句中完成的,直到我达到最小行#.1000循环-python时间:0005,使用nb.jit(nopython=True):.0019,nb.jit(nb.types.float64[:,:](nb.types.float64[,:],nb.int32),nopython=True)时间:1.5624e-05。这是我脑子里想不到的,也许不是“大规模的”,但我很幸运地在funcs中添加了sigs,在funcs中,nopython模式本身性能很差。@mott在计时过程中,您是否排除了第一个调用(包括编译成本)?我对numba的内部结构略知一二,很少有类型化函数更快的情况,大多数情况下它们会稍微慢一点(如果数组是连续的,则也会键入)。不过,如果您能创建一个gist(GitHub)或其他我可以查看的东西,我将不胜感激。我返回到我的计时,发现我没有排除第一次呼叫。我测试了两个Nopyton函数,两个都以相同的速度运行。所以你的前提可能是正确的,在大多数情况下签名是不必要的。。我可能在我以前使用过的函数中有特殊的输入类型,这可能会导致加速,如果我发现我将向您发送一个要点链接的话。谢谢你们的友好讨论。
import numba as nb
import numpy as np

jitkw = {"nopython": True, "nogil": True, "error_model": "numpy", "fastmath": True}

@nb.jit(
    [nb.types.float64(nb.types.misc.Omitted(None)), 
     nb.types.float64(nb.types.int64)], 
    **jitkw)
def get_random(seed=None):
    return np.random.normal()