Python Numpy:如何高效地获取每行的topN元素?

Python Numpy:如何高效地获取每行的topN元素?,python,numpy,sorting,Python,Numpy,Sorting,我尝试为每一行获取(索引、值)元组的topN列表 heapq过程中仅使用单芯, 然后我尝试使用多处理,但时间消耗更长 有没有更快的方法获得结果 谢谢 import heapq import multiprocessing import numpy import time class C1: def __init__(self): self.data = numpy.random.rand(100, 50000) self.top_n = 5000

我尝试为每一行获取(索引、值)元组的topN列表

heapq过程中仅使用单芯, 然后我尝试使用多处理,但时间消耗更长

有没有更快的方法获得结果

谢谢

import heapq
import multiprocessing
import numpy
import time


class C1:

    def __init__(self):
        self.data = numpy.random.rand(100, 50000)
        self.top_n = 5000

    def run_normal(self):
        output = []
        for item_index in range(self.data.shape[0]):
            objs = heapq.nlargest(self.top_n, enumerate(self.data[item_index]), lambda x: x[1])
            output.append(objs)

    def run_mp(self):
        with multiprocessing.Pool() as pool:
            output = pool.map(self.sort_arr, self.data.tolist())

    def sort_arr(self, arr):
        return heapq.nlargest(self.top_n, enumerate(arr), lambda x: x[1])


if __name__ == '__main__':
    c1 = C1()

    start = time.time()
    c1.run_normal()
    print(time.time() - start)

    start = time.time()
    c1.run_mp()
    print(time.time() - start)

输出

3.2407033443450928#用于循环时间
12.38778534164429#多处理时间

使用numpy.argsort可以轻松获取前n行:

import numpy as np
data = np.random.rand(100, 50000)
top_n = 5000

indices = np.argsort(data)[:, :top_n]
top_data = data[:, indices]

这比直接在Python中执行迭代要快。

要清楚地说明问题:

我们得到了一个包含数据点的mxnnumpy数组。我们希望获得一个M x k,其中每一行包含原始数组中的前k个值,并与原始行中的值索引配对

例如:对于[[1,2]、[4,3]、[5,6]]和k=1的输入,我们希望输出[[(0,1)]、[(1,3)]、[(0,5)]

解决方案 最好、最快的解决方案是使用本机numpy功能。策略是首先获取每行的顶级索引,然后从这些索引中获取元素,然后将两者合并到输出数组中

data = np.random(100, 50000)  # large
k = 5

# Define the type of our output array elements: (int, float)
dt = np.dtype([('index', np.int32, 1), ('value', np.float64, 1)])

# Take the indices of the largest k elements from each row
top_k_inds = np.argsort(data)[:, -1:-k - 1:-1]

# Take the values at those indices
top_k = np.take_along_axis(data, top_k_inds, axis=-1)

# Stack the two together along a third axis (to get index-value pairs)
top_k_pairs = np.stack((top_k_inds, top_k), axis=2)

# Convert the type (otherwise we have the indices as floats)
top_k_pairs = top_k_pairs.astype(dt)

在我的机器上分别得到4.46和4.16,所以多处理也带来了改进,考虑使用<代码> TimeTime.TimeIT(C1.RuniMP,编号=1)测量逝去的时间。