Python: Can't create a multidimensional array in numba


I am trying to do the following:

import numpy as np
import numba as nb

@nb.njit
def test(x):
    return np.array([[x, x], 
                     [x, x]])

test(np.array([5,5]))
But this fails with:

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Invalid use of Function(<built-in function array>) with argument(s) of type(s): (list(list(array(int64, 1d, C))))
 * parameterized
In definition 0:
    TypingError: array(int64, 1d, C) not allowed in a homogeneous sequence
    raised from /home/bellinger/anaconda3/lib/python3.7/site-packages/numba/typing/npydecl.py:460
In definition 1:
    TypingError: array(int64, 1d, C) not allowed in a homogeneous sequence
    raised from /home/bellinger/anaconda3/lib/python3.7/site-packages/numba/typing/npydecl.py:460
This error is usually caused by passing an argument of a type that is unsupported by the named function.
[1] During: resolving callee type: Function(<built-in function array>)
[2] During: typing of call at <ipython-input-108-17a4ebeac76c> (4)


File "<ipython-input-108-17a4ebeac76c>", line 4:
def test(x):
    <source elided>
    return np.array([[x, x], 
                     [x, x]])
Working case, after simplifying the function body to return np.array([x, x]):

In [116]: test(10)                                                                                   
Out[116]: array([10, 10])
Trying a list works, but with a warning:

In [117]: test([1,2])                                                                                
/usr/local/lib/python3.6/dist-packages/numba/core/ir_utils.py:2031: NumbaPendingDeprecationWarning: 
Encountered the use of a type that is scheduled for deprecation: type 'reflected list' found for argument 'x' of function 'test'.

For more information visit http://numba.pydata.org/numba-doc/latest/reference/deprecation.html#deprecation-of-reflection-for-list-and-set-types

File "<ipython-input-115-c82b85bdc507>", line 2:
@nb.njit
def test(x):
^

  warnings.warn(NumbaPendingDeprecationWarning(msg, loc=loc))
Out[117]: 
array([[1, 2],
       [1, 2]])

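You are asking numba's version of np.array to implement all the nuances of the numpy version, in this case making an array from a nested list of arrays. In numpy the result is a (2,2,2) array. Don't be surprised that numba doesn't have the same flexibility.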
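For comparison, here is the same construction in plain NumPy, with no numba involved. This is a minimal sketch with illustrative variable names: it confirms the (2,2,2) shape and shows that stacking a tuple of the four arrays and reshaping yields identical values.

```python
import numpy as np

x = np.array([5, 5])

# Plain NumPy builds a 3-D array from a nested list of 1-D arrays:
# the two list levels become the first two axes, x supplies the last one.
a = np.array([[x, x],
              [x, x]])
print(a.shape)  # (2, 2, 2)

# Stacking a tuple of the four arrays, then reshaping, gives the same result.
b = np.stack((x, x, x, x)).reshape((2, 2) + x.shape)
print(np.array_equal(a, b))  # True
```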
Passing an array, however, raises the same kind of TypingError:

In [118]: test(np.array([1,2]))
---------------------------------------------------------------------------
TypingError                               Traceback (most recent call last)
<ipython-input-118-521455fb3f7f> in <module>
----> 1 test(np.array([1,2]))

/usr/local/lib/python3.6/dist-packages/numba/core/dispatcher.py in _compile_for_args(self, *args, **kws)
    413                 e.patch_message(msg)
    414 
--> 415             error_rewrite(e, 'typing')
    416         except errors.UnsupportedError as e:
    417             # Something unsupported is present in the user code, add help info

/usr/local/lib/python3.6/dist-packages/numba/core/dispatcher.py in error_rewrite(e, issue_type)
    356                 raise e
    357             else:
--> 358                 reraise(type(e), e, None)
    359 
    360         argtypes = []

/usr/local/lib/python3.6/dist-packages/numba/core/utils.py in reraise(tp, value, tb)
     78         value = tp()
     79     if value.__traceback__ is not tb:
---> 80         raise value.with_traceback(tb)
     81     raise value
     82 

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<built-in function array>) found for signature:
 
 >>> array(list(array(int64, 1d, C)))
 
There are 2 candidate implementations:
  - Of which 2 did not match due to:
  Overload in function 'array': File: numba/core/typing/npydecl.py: Line 504.
    With argument(s): '(list(array(int64, 1d, C)))':
   Rejected as the implementation raised a specific error:
     TypingError: array(int64, 1d, C) not allowed in a homogeneous sequence
  raised from /usr/local/lib/python3.6/dist-packages/numba/core/typing/npydecl.py:471

During: resolving callee type: Function(<built-in function array>)
During: typing of call at <ipython-input-115-c82b85bdc507> (3)


File "<ipython-input-115-c82b85bdc507>", line 3:
def test(x):
    return np.array([x,x])
    ^

Building the array with np.stack from a tuple of arrays, then reshaping, does work:

In [143]: @nb.njit() 
     ...: def test(x): 
     ...:     temp = np.stack((x,x,x,x))     # tuple is important
     ...:     return temp.reshape((2,2)+(x.shape)) 

In [147]: test(np.array([1,2]))                                                                      
Out[147]: 
array([[[1, 2],
        [1, 2]],

       [[1, 2],
        [1, 2]]])
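Another option, sketched here under the assumption that x is 1-D, is to preallocate the output and copy x into each slot; explicit allocation plus slice assignment are ordinary nopython-mode operations. The function name tile_2x2 is hypothetical, and the try/except fallback exists only so the snippet still runs (uncompiled) where numba is not installed.

```python
import numpy as np

try:
    import numba as nb
    njit = nb.njit
except ImportError:
    # Hypothetical fallback: run as plain Python when numba is unavailable.
    def njit(func):
        return func


@njit
def tile_2x2(x):
    # Preallocate a (2, 2, n) result with the same dtype as the 1-D input x.
    out = np.empty((2, 2, x.shape[0]), dtype=x.dtype)
    # Copy x into each of the four (i, j) slots via slice assignment,
    # avoiding np.array() on a nested list of arrays.
    for i in range(2):
        for j in range(2):
            out[i, j, :] = x
    return out
```

With numba installed, tile_2x2(np.array([1, 2])) should return the same (2, 2, 2) array as the np.stack version above; the loop form also adapts easily to other grid sizes.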