Python Numpy:如何高效地获取每行的topN元素?
我尝试为每一行获取(索引、值)元组的topN列表 heapq过程中仅使用单芯, 然后我尝试使用多处理,但时间消耗更长 有没有更快的方法获得结果 谢谢Python Numpy:如何高效地获取每行的topN元素?,python,numpy,sorting,Python,Numpy,Sorting,我尝试为每一行获取(索引、值)元组的topN列表 heapq过程中仅使用单芯, 然后我尝试使用多处理,但时间消耗更长 有没有更快的方法获得结果 谢谢 import heapq import multiprocessing import numpy import time class C1: def __init__(self): self.data = numpy.random.rand(100, 50000) self.top_n = 5000
import heapq
import multiprocessing
import numpy
import time
class C1:
def __init__(self):
self.data = numpy.random.rand(100, 50000)
self.top_n = 5000
def run_normal(self):
output = []
for item_index in range(self.data.shape[0]):
objs = heapq.nlargest(self.top_n, enumerate(self.data[item_index]), lambda x: x[1])
output.append(objs)
def run_mp(self):
with multiprocessing.Pool() as pool:
output = pool.map(self.sort_arr, self.data.tolist())
def sort_arr(self, arr):
return heapq.nlargest(self.top_n, enumerate(arr), lambda x: x[1])
if __name__ == '__main__':
c1 = C1()
start = time.time()
c1.run_normal()
print(time.time() - start)
start = time.time()
c1.run_mp()
print(time.time() - start)
输出
3.2407033443450928#用于循环时间
12.38778534164429#多处理时间
使用numpy.argsort可以轻松获取前n行:
import numpy as np
data = np.random.rand(100, 50000)
top_n = 5000
indices = np.argsort(data)[:, :top_n]
top_data = data[:, indices]
这比直接在Python中执行迭代要快。要清楚地说明问题: 我们得到了一个包含数据点的mxnnumpy数组。我们希望获得一个M x k,其中每一行包含原始数组中的前k个值,并与原始行中的值索引配对 例如:对于[[1,2]、[4,3]、[5,6]]和k=1的输入,我们希望输出[[(0,1)]、[(1,3)]、[(0,5)] 解决方案 最好、最快的解决方案是使用本机numpy功能。策略是首先获取每行的顶级索引,然后从这些索引中获取元素,然后将两者合并到输出数组中
data = np.random(100, 50000) # large
k = 5
# Define the type of our output array elements: (int, float)
dt = np.dtype([('index', np.int32, 1), ('value', np.float64, 1)])
# Take the indices of the largest k elements from each row
top_k_inds = np.argsort(data)[:, -1:-k - 1:-1]
# Take the values at those indices
top_k = np.take_along_axis(data, top_k_inds, axis=-1)
# Stack the two together along a third axis (to get index-value pairs)
top_k_pairs = np.stack((top_k_inds, top_k), axis=2)
# Convert the type (otherwise we have the indices as floats)
top_k_pairs = top_k_pairs.astype(dt)
在我的机器上分别得到4.46和4.16,所以多处理也带来了改进,考虑使用<代码> TimeTime.TimeIT(C1.RuniMP,编号=1)测量逝去的时间。