如何加速python';带Numba的词典

如何加速python';带Numba的词典,python,arrays,python-3.x,dictionary,numba,Python,Arrays,Python 3.x,Dictionary,Numba,我需要在布尔值数组中存储几个单元格。起初我使用numpy,但当数组开始占用大量内存时,我想到了在字典中存储非零元素,并将元组作为键(因为它是可哈希类型)。对于emaxple: {(0,0,0):True,(1,2,3):True}(这是“3D数组”中的两个单元格,索引为0,0,0和1,2,3,但维数事先未知,并在运行算法时定义)。 这很有帮助,因为非零单元只填充了阵列的一小部分 为了从这个dict中写入和获取值,我需要使用循环: def fill_cells(indices, area_dict

我需要在布尔值数组中存储几个单元格。起初我使用numpy,但当数组开始占用大量内存时,我想到了在字典中存储非零元素,并将元组作为键(因为它是可哈希类型)。对于emaxple:
{(0,0,0):True,(1,2,3):True}
(这是“3D数组”中的两个单元格,索引为0,0,0和1,2,3,但维数事先未知,并在运行算法时定义)。 这很有帮助,因为非零单元只填充了阵列的一小部分

为了从这个dict中写入和获取值,我需要使用循环:

def fill_cells(indices, area_dict):
    for i in indices:
        area_dict[tuple(i)] = 1

def get_cells(indices, area_dict):
    n = len(indices)
    out = np.zeros(n, dtype=np.bool)
    for i in range(n):
        out[i] = tuple(indices[i]) in area_dict.keys()
    return out
现在我需要用麻木加快速度。Numba不支持本机Python的dict(),所以我使用了Numba.typed.dict。 问题是,在定义函数的阶段,Numba想知道键的大小,所以我甚至无法创建字典(键的长度事先未知,在调用函数时定义):

Numba无法正确推断字典键的类型,并返回错误:

Compilation is falling back to object mode WITH looplifting enabled because Function "make_dict" failed type inference due to: Invalid use of Function(<built-in function mul>) with argument(s) of type(s): (tuple(int64 x 1), int64)
但我认为这是错误的。我需要用@njit decorator使用fill_cells和get_cells函数,但是Numba返回了相同的错误,因为我试图在这个函数中从numpy数组创建元组

我理解Numba的基本限制(以及一般的编译),但也许有某种方法可以加快函数的速度,或者,也许您有另一种解决我的单元存储问题的方法?

最终解决方案: 主要问题是,在定义创建元组的函数时,Numba需要知道元组的长度。诀窍是每次重新定义函数。我需要用定义函数的代码生成字符串,并用exec()运行它。

n = 10
s = '@njit\ndef arr_to_tuple(a):\n\treturn (' + ''.join('a[%i],' % i for i in range(n)) + ')'
exec(s)
之后,我可以调用arr\u to\u tuple(a)来创建可以在另一个@njit修饰函数中使用的元组

例如,创建元组键的空字典,这需要解决问题:

@njit
def make_empty_dict():
    tpl = arr_to_tuple(np.array([0]*5))
    out = {tpl:True}
    del out[tpl]
    return out
我在字典中写了一个元素,因为这是Numba推断类型的方法之一

此外,我需要使用问题中描述的填充单元格获取单元格功能。这就是我用Numba重写它们的方式:

写作元素。刚刚将tuple()更改为arr\u to\u tuple():

从字典中获取元素需要一些令人毛骨悚然的代码:

@njit
def get_cells_nb(indices, area_dict):
    n = len(indices)
    out = np.zeros(n, dtype=np.bool_)
    for i in range(n):
        new_len = len(area_dict)
        tpl = arr_to_tuple(indices[i])
        area_dict[tpl] = True
        old_len = len(area_dict)
        if new_len == old_len:
            out[i] = True
        else:
            del area_dict[tpl]
    return out
我的版本的Numba(0.46)不支持.contains(in)运算符,请尝试使用EXPECT构造。如果您有支持它的版本,您可以为它编写更多的“常规”解决方案

因此,当我想检查字典中是否存在具有某个索引的元素时,我会记住它的长度,然后在字典中写下具有所述索引的内容。如果长度改变了,我就认为元素不存在。否则该元素将存在。看起来解决方案很慢,但事实并非如此

速度测试: 解决方案出奇地快。与本机Python优化代码相比,我使用%timeit对其进行了测试:

  • arr\u to\u tuple()比常规tuple()函数快5倍
  • 使用numba获取单元格与本机Python编写的获取单元格相比,一个元素快3倍,大型元素数组快40倍
  • 使用numba填充单元格与本机Python编写的填充单元格相比,一个元素快4倍,大型元素数组快40倍

  • 你考虑过稀疏矩阵吗?@Marat是的,我根据键字典(函数fill_cells和get_cells是这个实现的一部分)自己实现了稀疏矩阵。我意识到这是稀疏矩阵的常见解决方案。问题是我需要加快这个实现。另外,我不需要对它进行矩阵运算,只需要存储和获取值,也许它可以扩展可能的解决方案集。像DICT这样的原生数据结构效率很低。scipy.sparse提供了C实现,它的性能可能比本机结构高出一个数量级。@Marat是的,我发现scipy.sparse比我的解决方案快,但它只适用于2D矩阵。我需要处理任意维度。我还没有找到比自己写并用Numba加快速度更好的解决方案(这就是我现在要做的,我在问题中描述了问题)。你以前见过吗?当我想使用数组作为矩阵的索引时,我遇到了同样的问题。它只需要在顶部进行一个小的调整,因为dict没有ndim。您是否将性能与键入的列表进行了比较?似乎您实际上不需要存储“True”,因为存储索引已经暗示了这一点。还可以考虑编写<代码> unRaveLodeIndex <代码>和<代码> RavuluMulthIndex 函数,类似于NUMPY,使索引始终存储1D。@ RutGrkases,我在评论后比较了它。使用类型化列表的速度要慢得多,因为需要在循环中检查列表的元素才能进行检查。函数的执行时间取决于列表的大小,与Dict相反,Dict由于散列键而具有恒定的时间。顺便说一下,我发现对于这样的问题,使用本机Python hash()函数并在循环中使用它而不是ravel_multi_索引可能是合理的(更快)。这个循环可以用@njit修饰,而不需要对代码进行重大更改。
    @njit
    def make_empty_dict():
        tpl = arr_to_tuple(np.array([0]*5))
        out = {tpl:True}
        del out[tpl]
        return out
    
    @njit
    def fill_cells_nb(indices, area_dict):
        for i in range(len(indices)):
            area_dict[arr_to_tuple(indices[i])] = True
    
    @njit
    def get_cells_nb(indices, area_dict):
        n = len(indices)
        out = np.zeros(n, dtype=np.bool_)
        for i in range(n):
            new_len = len(area_dict)
            tpl = arr_to_tuple(indices[i])
            area_dict[tpl] = True
            old_len = len(area_dict)
            if new_len == old_len:
                out[i] = True
            else:
                del area_dict[tpl]
        return out