Python: How do I return a boolean array in numba.njit? [Python, Arrays, Numpy, Boolean, Numba] - Fatal编程技术网

import numpy as np
from numba import njit, float64
from numba.experimental import jitclass

@njit(fastmath=True)
def compare(values1, values2):
    shape = values1.shape[0]
    res = np.zeros(shape, dtype=bool)  # builtin bool: this is the call rejected in the traceback below

    # element-wise comparison of the two input arrays
    for i in range(shape):
        res[i] = values1[i] > values2[i]

    return res

spce = [("x", float64[:]),
        ("y", float64[:]),
        ("z", float64[:]),]

@jitclass(spce)
class Math:

    def __init__(self, x, y, z):
        self.x = x
        self.y = y
        self.z = z

    def calculate(self):
        # boolean mask that is True where x > y
        i = compare(self.x, self.y)
        return self.z[i]

If I test it like this:

x = np.random.rand(10)
y = np.random.rand(10)
compare(x, y)

It returns:

Traceback (most recent call last):

  File "<ipython-input-25-586dc5d173c7>", line 3, in <module>
    compare(x, y)

  File "C:\Users\Option00\Anaconda3\envs\bot\lib\site-packages\numba\core\dispatcher.py", line 415, in _compile_for_args
    error_rewrite(e, 'typing')

  File "C:\Users\Option00\Anaconda3\envs\bot\lib\site-packages\numba\core\dispatcher.py", line 358, in error_rewrite
    reraise(type(e), e, None)

  File "C:\Users\Option00\Anaconda3\envs\bot\lib\site-packages\numba\core\utils.py", line 80, in reraise
    raise value.with_traceback(tb)

TypingError: No implementation of function Function(<built-in function zeros>) found for signature:

zeros(int64, dtype=Function(<class 'bool'>))

There are 2 candidate implementations:
  - Of which 2 did not match due to:
  Overload of function 'zeros': File: numba\core\typing\npydecl.py: Line 504.
    With argument(s): '(int64, dtype=Function(<class 'bool'>))':
   No match.

During: resolving callee type: Function(<built-in function zeros>)
During: typing of call at <ipython-input-24-69a4f907fb89> (4)

In NumPy the output would simply be z[x > y], but how can I do the same inside njit and jitclass?

x = np.random.rand(10)
y = np.random.rand(10)
z = np.random.rand(10)

m = Math(x, y, z)
m.calculate()

I need both of them to speed up the rest of my code.


If the compare function can return a boolean array, the problem should be solved.

You have to use Numba's special bool_ type:

import numpy as np
from numba import njit
from numba.types import bool_, int_, float32

@njit(bool_[:,:](float32[:,:,:], float32[:,:,:], int_))
def test(im1, im2, j_delta=1):
    # mean squared difference over the last axis
    diff = ((im1 - im2)**2).sum(axis=2)/3
    # all-False 2-D mask with the same shape as diff
    mask = np.zeros_like(diff, bool_)
    return mask

If you replace bool_ with bool or even np.bool, you will get a compilation error.