Python 如何使用Pytorch和/或Numpy有效地找到多维矩阵数组中最大值的索引 背景

Python 如何使用Pytorch和/或Numpy有效地找到多维矩阵数组中最大值的索引 背景,python,numpy,pytorch,max,numba,Python,Numpy,Pytorch,Max,Numba,在机器学习中,处理高维数据是很常见的。例如,在卷积神经网络(CNN)中,每个输入图像的尺寸可以是256x256,并且每个图像可以具有3个颜色通道(红色、绿色和蓝色)。如果我们假设模型一次接收16幅图像,那么进入CNN的输入的维度是[16,3256256]。每个单独的卷积层期望数据的形式为[批次大小,在通道中,在y中,在x中],并且所有这些数量通常在层与层之间变化(批次大小除外)。我们用于表示由[in_y,in_x]值组成的矩阵的术语是特征图,该问题涉及在给定层的每个特征图中查找最大值及其索引 我

在机器学习中,处理高维数据是很常见的。例如,在卷积神经网络(CNN)中,每个输入图像的尺寸可以是256x256,并且每个图像可以具有3个颜色通道(红色、绿色和蓝色)。如果我们假设模型一次接收16幅图像,那么进入CNN的输入的维度是
[16,3256256]
。每个单独的卷积层期望数据的形式为
[批次大小,在通道中,在y中,在x中]
,并且所有这些数量通常在层与层之间变化(批次大小除外)。我们用于表示由
[in_y,in_x]
值组成的矩阵的术语是特征图,该问题涉及在给定层的每个特征图中查找最大值及其索引

我为什么要这样做?我想对每个特征贴图应用一个遮罩,我想在每个特征贴图中以最大值为中心应用遮罩,要做到这一点,我需要知道每个最大值的位置。此掩码应用程序在模型的训练和测试期间完成,因此效率对于减少计算时间至关重要。有许多Pytorch和Numpy解决方案可用于查找单个最大值和索引,以及查找单个维度上的最大值或索引,但没有(我可以找到)专用且高效的内置函数可用于一次查找两个或更多维度上的最大值索引。是的,我们可以嵌套在单个维度上运行的函数,但这些是一些效率最低的方法

我试过的
  • 我已经看过了,但是作者正在处理一个特殊情况的4D数组,它被简单地压缩成一个3D数组。被接受的答案专门针对这种情况,而指向TopK的答案被误导了,因为它不仅在单一维度上运行,而且在给定问题时需要
    k=1
    ,从而解决常规
    torch.max
    呼叫
  • 我已经看过了,但是这个问题和它的答案,集中在一个维度上
  • 我已经看过了,但我已经知道答案的方法,因为我在自己的答案中独立地阐述了它(我在其中修正了该方法是非常低效的)
  • 我已经看过了,但它不能满足这个问题的关键部分,即效率问题
  • 我已经阅读了许多其他Stackoverflow问题和答案,以及Numpy文档、Pytorch文档和Pytorch论坛上的帖子
  • 我已经尝试了很多不同的方法来解决这个问题,我已经提出了这个问题,这样我就可以回答这个问题并回馈给社区,以及任何在未来寻求解决这个问题的方法的人
绩效标准 如果我问一个关于效率的问题,我需要清楚地详细说明期望值。我正试图为上述问题找到一个时间效率高的解决方案(空间是次要的),而不必编写C代码/扩展,而且它相当灵活(我所追求的不是超专业化的方法)。该方法必须接受数据类型为float32或float64的
[a,b,c,d]
Torch张量作为输入,并输出数据类型为int32或int64的
[a,b,2]
形式的数组或张量(因为我们将输出用作索引)。 应根据以下典型解决方案对解决方案进行基准测试:

max_indices = torch.stack([torch.stack([(x[k][j]==torch.max(x[k][j])).nonzero()[0] for j in range(x.size()[1])]) for k in range(x.size()[0])])
@njit(cache=True)
def indexFunc(array, item):
    for idx, val in np.ndenumerate(array):
        if val == item:
            return idx
    raise RuntimeError

@njit(cache=True, parallel=True)
def indexFunc2(x,maxVals):
    max_indices = np.zeros((x.shape[0],x.shape[1],2),dtype=np.int64)
    for i in prange(x.shape[0]):
        for j in prange(x.shape[1]):
            max_indices[i,j] = np.asarray(indexFunc(x[i,j], maxVals[i,j]),dtype=np.int64)
    return max_indices

x = x.numpy()
maxVals = np.amax(x, axis=(2,3))
max_indices = torch.from_numpy(indexFunc2(x,maxVals))
方法 我们将利用Numpy社区和库,以及Pytorch张量和Numpy数组可以相互转换的事实,而无需复制或移动内存中的底层数组(因此转换成本较低)。从:

将火炬张量转换为Numpy数组或将火炬张量转换为Numpy数组是轻而易举的事。torch张量和Numpy数组将共享它们的底层内存位置,更改其中一个将更改另一个

解决方案一 我们首先将使用编写一个函数,该函数在第一次使用时将被实时(JIT)编译,这意味着我们可以获得C速度,而无需自己编写C代码。当然,对于可以获得JIT ed的内容有一些警告,其中一个警告是我们使用Numpy函数。但这并不是太糟糕,因为,记住,从火炬张量转换为Numpy的成本很低。我们创建的功能是:

@njit(cache=True)
def indexFunc(array, item):
    for idx, val in np.ndenumerate(array):
        if val == item:
            return idx
此函数如果来自另一个Stackoverflow答案(这是将我介绍给Numba的答案)。该函数采用N维Numpy数组,并查找给定
项的第一个匹配项。如果匹配成功,它会立即返回找到的项的索引。
@njit
decorator是
@jit(nopython=True)
的缩写,它告诉编译器我们希望它不使用Python对象编译函数,如果它不能这样做,就会抛出一个错误(当不使用Python对象时,Numba是最快的,速度是我们追求的)

有了这个快速函数的支持,我们可以得到张量中最大值的指数,如下所示:

import numpy as np

x =  x.numpy()
maxVals = np.amax(x, axis=(2,3))
max_indices = np.zeros((n,p,2),dtype=np.int64)
for index in np.ndindex(x.shape[0],x.shape[1]):
    max_indices[index] = np.asarray(indexFunc(x[index], maxVals[index]),dtype=np.int64)
max_indices = torch.from_numpy(max_indices)
我们使用
np.amax
,因为它可以接受其
参数的元组,允许它返回4D输入中每个2D特征映射的最大值。我们用
np.zero
提前初始化
max_索引
,因为这样我们就提前分配了所需的空间。这种方法比问题中的典型解决方案快得多(一个数量级),但它也在JIT ed函数外使用
for
循环,因此我们可以改进

解决方案二 我们将使用以下解决方案:

max_indices = torch.stack([torch.stack([(x[k][j]==torch.max(x[k][j])).nonzero()[0] for j in range(x.size()[1])]) for k in range(x.size()[0])])
@njit(cache=True)
def indexFunc(array, item):
    for idx, val in np.ndenumerate(array):
        if val == item:
            return idx
    raise RuntimeError

@njit(cache=True, parallel=True)
def indexFunc2(x,maxVals):
    max_indices = np.zeros((x.shape[0],x.shape[1],2),dtype=np.int64)
    for i in prange(x.shape[0]):
        for j in prange(x.shape[1]):
            max_indices[i,j] = np.asarray(indexFunc(x[i,j], maxVals[i,j]),dtype=np.int64)
    return max_indices

x = x.numpy()
maxVals = np.amax(x, axis=(2,3))
max_indices = torch.from_numpy(indexFunc2(x,maxVals))
我们可以使用Numba的
prange
fu利用并行化,而不是使用
for
循环一次迭代一个功能映射
@njit('i8[:,:,:](f4[:,:,:,:])',cache=True, parallel=True)
def indexFunc4(x):
    max_indices = np.zeros((x.shape[0],x.shape[1],2),dtype=np.int64)
    for i in prange(x.shape[0]):
        for j in prange(x.shape[1]):
            maxTemp = np.argmax(x[i][j])
            max_indices[i][j] = [maxTemp // x.shape[2], maxTemp % x.shape[2]] 
    return max_indices    

max_indices6 = torch.from_numpy(indexFunc4(x))
@njit('i8[:,:,::1](f4[:,:,:,::1])',cache=True, parallel=True)
def indexFunc5(x):
    max_indices = np.zeros((x.shape[0],x.shape[1],2),dtype=np.int64)
    for i in prange(x.shape[0]):
        for j in prange(x.shape[1]):
            maxTemp = np.argmax(x[i][j])
            max_indices[i][j] = [maxTemp // x.shape[2], maxTemp % x.shape[2]] 
    return max_indices 

max_indices7 = torch.from_numpy(indexFunc5(x))