Python: alternatives to advanced indexing in an njit numba function

Given the following minimal reproducible example:

import numpy as np
from numba import jit

# variable number of dimensions
n_t = 8
# q is just a partition of n
q_ddl = 2
n_ddl = 3

np.random.seed(42)
df = np.random.rand(q_ddl*n_t,q_ddl*n_t)

# index array
# ddl_nl is a set of np.arange(n_ddl), ex: [0,1] ; [0,2] or even [0] ...
ddl_nl = np.array([0,1])
ij = np.asarray(np.meshgrid(ddl_nl,ddl_nl,indexing='ij'))

@jit(nopython=True)
def foo(df,ij):
    out = np.zeros((n_t,n_ddl,n_ddl))
    for i in range(0,n_t):     
        d_i = np.zeros((n_ddl,n_ddl))
        # (q_ddl,q_ddl) non zero values into (n_ddl,n_ddl) shape
        d_i[ij[0], ij[1]] = df[i::n_t,i::n_t]
        # to check possible solutions
        out[i,...] = d_i
    return out


out_foo = foo(df,ij)
With @jit(nopython=True) disabled, the function foo works fine, but with it enabled it raises the following error:

TypeError: unsupported array index type array(int64, 2d, C) in UniTuple(array(int64, 2d, C) x 2)
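This happens during the broadcasting assignment d_i[ij[0], ij[1]] = df[i::n_t, i::n_t]. I then tried flattening the 2D index array ij with something like d_i[ij[0].ravel(), ij[1].ravel()] = df[i::n_t, i::n_t].ravel(), which gives me the same output but now raises a different error: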

NotImplementedError: only one advanced index supported
So I finally tried to sidestep the problem with a classic structure of two nested for loops:

tmp = df[i::n_t,i::n_t]
for k,r in enumerate(ddl_nl):
    for l,c in enumerate(ddl_nl):
        d_i[r,c] = tmp[k,l]
This works with the decorator enabled and gives the expected result.


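But I can't stop wondering whether there is a numba-compatible alternative to this NumPy 2D array broadcasting operation that I am missing here. Any help would be appreciated.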

Checking some of your values:

In [446]: ddl_nl = np.array([0,1]) 
     ...: ij = np.asarray(np.meshgrid(ddl_nl,ddl_nl,indexing='ij')) 
     ...:                                                                                      
In [447]: ij                                                                                   
Out[447]: 
array([[[0, 0],
        [1, 1]],

       [[0, 1],
        [0, 1]]])
In [448]: n_t = 8 
     ...: q_ddl = 2 
     ...: n_ddl = 3                                                                            
In [449]: d_i = np.zeros((n_ddl,n_ddl))                                                        
In [450]: d_i                                                                                  
Out[450]: 
array([[0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.]])
In [451]: d_i[ij[0], ij[1]]                                                                    
Out[451]: 
array([[0., 0.],
       [0., 0.]])
Try a more diagnostic check:

In [452]: d_i = np.arange(9).reshape(3,3)                                                      
In [453]: d_i[ij[0], ij[1]]                                                                    
Out[453]: 
array([[0, 1],
       [3, 4]])
In [454]: d_i[:2,:2]                                                                           
Out[454]: 
array([[0, 1],
       [3, 4]])
Why use advanced indexing when basic slicing works?

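I haven't tried this with numba, but slicing probably has a better chance of working. That said, the enumerate loops may be just as fast. I don't have enough experience to say for sure.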
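
A minimal sketch of that slicing idea, assuming ddl_nl is always the contiguous range np.arange(q_ddl) (as with [0, 1] above; it would not cover a set such as [0, 2]). The name foo_slice is hypothetical, the sizes are passed as arguments instead of being read from globals, and the try/except fallback is only an assumption so that the snippet still runs where numba is not installed:

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # fallback (assumption): run as plain Python without numba
    def njit(func):
        return func


@njit
def foo_slice(df, n_t, n_ddl, q_ddl):
    # Assumes the target indices are exactly 0..q_ddl-1, so basic slices suffice.
    out = np.zeros((n_t, n_ddl, n_ddl))
    for i in range(n_t):
        # The strided (q_ddl, q_ddl) block goes into the top-left corner.
        out[i, :q_ddl, :q_ddl] = df[i::n_t, i::n_t]
    return out


np.random.seed(42)
n_t, q_ddl, n_ddl = 8, 2, 3
df = np.random.rand(q_ddl * n_t, q_ddl * n_t)
out_slice = foo_slice(df, n_t, n_ddl, q_ddl)
```

Only basic indexing (integers and slices) is used here, which is why neither of the two errors from the question should be triggered.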

===

Evidently you are performing a numpy operation that numba does not support:

In [456]: numba.__version__                                                                    
Out[456]: '0.43.0'
In [457]: @numba.jit 
     ...: def foo(arr): 
     ...:     return arr[[1,2,3],[1,2,3]] 
     ...:                                                                                      
In [458]: foo(np.eye(4))                                                                       
Out[458]: array([1., 1., 1.])
In [459]: @numba.njit 
     ...: def foo(arr): 
     ...:     return arr[[1,2,3],[1,2,3]] 
     ...:                                                                                      
In [460]: foo(np.eye(4))    
TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Invalid use of Function(<built-in function getitem>) with argument(s) of type(s): (array(float64, 2d, C), tuple(list(int64) x 2))
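
If an index-array formulation is still wanted, one possible workaround (a sketch, not something tested in this thread) is to collapse the two index arrays into a single 1D array of row-major positions, so that only one advanced index is involved. The name foo_flat is hypothetical, and the try/except fallback is an assumption so the snippet runs without numba:

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # fallback (assumption): run as plain Python without numba
    def njit(func):
        return func


@njit
def foo_flat(df, n_t, n_ddl, ddl_nl):
    q = ddl_nl.shape[0]
    # Row-major position in a flattened (n_ddl, n_ddl) array of every
    # (r, c) pair taken from ddl_nl.
    lin = np.empty(q * q, dtype=np.int64)
    for k in range(q):
        for l in range(q):
            lin[k * q + l] = ddl_nl[k] * n_ddl + ddl_nl[l]
    out = np.zeros((n_t, n_ddl, n_ddl))
    for i in range(n_t):
        d_flat = np.zeros(n_ddl * n_ddl)
        # One 1D index array on a 1D target: a single advanced index.
        d_flat[lin] = df[i::n_t, i::n_t].copy().reshape(q * q)
        out[i] = d_flat.reshape(n_ddl, n_ddl)
    return out


np.random.seed(42)
n_t, q_ddl, n_ddl = 8, 2, 3
df = np.random.rand(q_ddl * n_t, q_ddl * n_t)
ddl_nl = np.array([0, 2])  # a non-contiguous set, where plain slicing would not apply
out_flat = foo_flat(df, n_t, n_ddl, ddl_nl)
```

This keeps the index-array style at the cost of a copy per iteration, so plain nested loops may well be the simpler choice.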
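
Avoid fancy indexing

Also avoid global variables (they are hard-coded at compile time) and keep the code as simple as possible (simple meaning only a few loops, if/else, ...). If the ddl_nl array is really only ever built with np.arange, the array is not even needed at all.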

Example

import numpy as np
from numba import jit

@jit(nopython=True)
def foo_nb(df,n_ddl,n_t,ddl_nl):
    out = np.zeros((n_t,n_ddl,n_ddl))
    for i in range(0,n_t):
        for ii in range(ddl_nl.shape[0]):
            ind_1=ddl_nl[ii]
            for jj in range(ddl_nl.shape[0]):
                ind_2=ddl_nl[jj]
                out[i,ind_1,ind_2] = df[i+ii*n_t,i+jj*n_t]
    return out
Timings

#Testing and compilation
A=foo(df,ij)
B=foo_nb(df,n_ddl,n_t,ddl_nl)
print(np.allclose(A,B))
#True
%timeit foo(df,ij)
#16.8 µs ± 107 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit foo_nb(df,n_ddl,n_t,ddl_nl)
#674 ns ± 2.56 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
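
The remark about globals being hard-coded at compile time can be illustrated with a small sketch. N_T, read_global and read_argument are hypothetical names, and the HAVE_NUMBA fallback is an assumption so the snippet also runs without numba (in which case the global is simply re-read on every call):

```python
try:
    from numba import njit
    HAVE_NUMBA = True
except ImportError:  # fallback (assumption): plain Python, globals are not frozen
    HAVE_NUMBA = False

    def njit(func):
        return func

N_T = 8  # hypothetical module-level global


@njit
def read_global():
    # With numba, N_T is treated as a constant when the function is compiled.
    return N_T


@njit
def read_argument(n_t):
    # Values passed as arguments are read on every call.
    return n_t


first = read_global()    # first call triggers compilation
N_T = 16                 # rebinding the global afterwards
second = read_global()   # with numba, this should still see the compiled-in value
print(first, second, read_argument(N_T))
```

Passing sizes such as n_t and n_ddl as arguments, as foo_nb does, avoids this kind of stale value.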

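Define "numba friendly" :) Is the meshgrid also in your real function? Is the number of dimensions always the same? In this example you don't need anything like meshgrid or fancy indexing at all. Even if it worked, it would be slower than simple nested loops. Using globals like n_t is also not recommended (you can't change them without recompiling).

The number of dimensions is not always the same... but q is just a partition of n. By numba friendly I meant numba compatible.

You two are right, thanks for your answers: the loops are as fast as (or even faster than) the fancy approach. In this case I think I'll go that way.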