Python 高效地获取numpy.partition和numpy.argpartition输出

Python 高效地获取numpy.partition和numpy.argpartition输出,python,arrays,numpy,Python,Arrays,Numpy,我正在使用python 3.6和numpy。我有一个n维数组。我需要在数组的最后一个维度上执行partition和argpartition。我显然可以调用这两个函数,但这感觉像是在浪费资源。有没有办法同时得到np.partition和np.argpartition的结果?应该有一种方法可以得到np.partition的结果,将我从np.argpartition中得到的索引应用到数组中,但是我现在没有看到它 谢谢大家! 获取那些argpartition索引,然后使用获取分区数组 因此,任何维数和任

我正在使用python 3.6和numpy。我有一个n维数组。我需要在数组的最后一个维度上执行partition和argpartition。我显然可以调用这两个函数,但这感觉像是在浪费资源。有没有办法同时得到np.partition和np.argpartition的结果?应该有一种方法可以得到np.partition的结果,将我从np.argpartition中得到的索引应用到数组中,但是我现在没有看到它


谢谢大家!

获取那些
argpartition
索引,然后使用获取分区数组

因此,任何维数和任意通用轴的通用数据数组的实现是这样的-

def partition_results(a, k, axis=-1):
    idx = np.argpartition(a, k, axis=axis)
    index_arr = list(np.ix_(*[range(i) for i in a.shape]))
    index_arr[axis] =  idx
    return idx, a[index_arr]
def partition_results_exclusive_way(a, k):
    idx = np.argpartition(a, k, axis=-1)
    part_arr = np.partition(a, k, axis=-1)
    return idx , part_arr
为我们提供了“分散”范围数组,以完成高级索引的任务。这些范围数组需要覆盖与
argpartition
索引数组中的轴长度相对应的所有维度,但最后一个维度除外,我们有这些
argpartition
索引本身。这种索引操作需要此设置

因此,通过使用对
np.argpartition
np.partition
的两个单独调用的方法,我们可以得到它,就像这样-

def partition_results(a, k, axis=-1):
    idx = np.argpartition(a, k, axis=axis)
    index_arr = list(np.ix_(*[range(i) for i in a.shape]))
    index_arr[axis] =  idx
    return idx, a[index_arr]
def partition_results_exclusive_way(a, k):
    idx = np.argpartition(a, k, axis=-1)
    part_arr = np.partition(a, k, axis=-1)
    return idx , part_arr
在下一节中,我们将使用它来比较性能和价值验证

示例运行和运行时测试-

In [496]: a = np.random.rand(20,20,20,20,20)

In [502]: A0, B0 = partition_results_exclusive_way(a, 10)

In [503]: A1, B1 = partition_results(a, 10)

In [504]: np.allclose(A0,A1)
Out[504]: True

In [505]: np.allclose(B0,B1)
Out[505]: True

In [506]: %timeit partition_results_exclusive_way(a, 10)
10 loops, best of 3: 92.6 ms per loop

In [507]: %timeit partition_results(a, 10)
10 loops, best of 3: 76 ms per loop
进一步分析性能数据,让我们分别计算
argpartition
partition
-

In [509]: %timeit np.argpartition(a, 10, axis=-1)
10 loops, best of 3: 49.6 ms per loop

In [510]: %timeit np.partition(a, 10, axis=-1)
10 loops, best of 3: 43.6 ms per loop

因此,
高级索引
操作花费了我们使用
np.partition
的一半左右。我们肯定在那里节省

获取那些
argpartition
索引,然后使用获取分区数组

因此,任何维数和任意通用轴的通用数据数组的实现是这样的-

def partition_results(a, k, axis=-1):
    idx = np.argpartition(a, k, axis=axis)
    index_arr = list(np.ix_(*[range(i) for i in a.shape]))
    index_arr[axis] =  idx
    return idx, a[index_arr]
def partition_results_exclusive_way(a, k):
    idx = np.argpartition(a, k, axis=-1)
    part_arr = np.partition(a, k, axis=-1)
    return idx , part_arr
为我们提供了“分散”范围数组,以完成高级索引的任务。这些范围数组需要覆盖与
argpartition
索引数组中的轴长度相对应的所有维度,但最后一个维度除外,我们有这些
argpartition
索引本身。这种索引操作需要此设置

因此,通过使用对
np.argpartition
np.partition
的两个单独调用的方法,我们可以得到它,就像这样-

def partition_results(a, k, axis=-1):
    idx = np.argpartition(a, k, axis=axis)
    index_arr = list(np.ix_(*[range(i) for i in a.shape]))
    index_arr[axis] =  idx
    return idx, a[index_arr]
def partition_results_exclusive_way(a, k):
    idx = np.argpartition(a, k, axis=-1)
    part_arr = np.partition(a, k, axis=-1)
    return idx , part_arr
在下一节中,我们将使用它来比较性能和价值验证

示例运行和运行时测试-

In [496]: a = np.random.rand(20,20,20,20,20)

In [502]: A0, B0 = partition_results_exclusive_way(a, 10)

In [503]: A1, B1 = partition_results(a, 10)

In [504]: np.allclose(A0,A1)
Out[504]: True

In [505]: np.allclose(B0,B1)
Out[505]: True

In [506]: %timeit partition_results_exclusive_way(a, 10)
10 loops, best of 3: 92.6 ms per loop

In [507]: %timeit partition_results(a, 10)
10 loops, best of 3: 76 ms per loop
进一步分析性能数据,让我们分别计算
argpartition
partition
-

In [509]: %timeit np.argpartition(a, 10, axis=-1)
10 loops, best of 3: 49.6 ms per loop

In [510]: %timeit np.partition(a, 10, axis=-1)
10 loops, best of 3: 43.6 ms per loop

因此,
高级索引
操作花费了我们使用
np.partition
的一半左右。我们肯定在那里节省

谢谢你的回答!但它在三维数组上引发了一个错误,indexer错误:形状不匹配:无法广播索引数组together@Ant您是只使用3D阵列,还是正在寻找通用的ndarray(dims的通用数量)解决方案?我正在寻找通用的ndarray;我不希望维度超过4或5,但它肯定可以大于3:)@Ant查看对通用ndarray解决方案的编辑。@Ant对此添加了一些注释。谢谢您的回答!但它在三维数组上引发了一个错误,indexer错误:形状不匹配:无法广播索引数组together@Ant您是只使用3D阵列,还是正在寻找通用的ndarray(dims的通用数量)解决方案?我正在寻找通用的ndarray;我不希望维度超过4或5,但它肯定可以大于3:)@Ant查看对通用ndarray解决方案的编辑。@Ant对其添加了一些注释。