Python 使用numpy获取查找表索引的最快方法

Python 使用numpy获取查找表索引的最快方法,python,arrays,numpy,indexing,lookup,Python,Arrays,Numpy,Indexing,Lookup,接下来是这个问题,旨在加速以下代码。我(在帮助下)构建了一些代码,从mxnx3numpy.ndarray(RGB图像)中获取像素值,并将像素值与(位置)查找表中的像素值进行比较,并输出查找表的像素值索引,如下所示: img = np.random.randint(0, 9 , (3, 3, 3)) lut2 = img[0,0:2,:] output = [] for x in xrange(lut2.shape[0]): if lut2[x] in img:

接下来是这个问题,旨在加速以下代码。我(在帮助下)构建了一些代码,从mxnx3
numpy.ndarray
(RGB图像)中获取像素值,并将像素值与(位置)查找表中的像素值进行比较,并输出查找表的像素值索引,如下所示:

img = np.random.randint(0, 9 , (3, 3, 3))
lut2 = img[0,0:2,:]

output = []
for x in xrange(lut2.shape[0]):
    if lut2[x] in img:
            output.append(np.concatenate(np.where( (img == lut2[x]).sum(axis=2) == 3 )))

print np.array(output)
输出: 错误:如果在
img
中的多个位置获得
lut[x]
,将无法正确输出。也许更好的策略是将输出格式化为嵌套列表,因为任何可索引的都可以。例如:

[[x_px1, y_px1], [[x_px2a, y_px2a], [x_px2b, y_px2b]]]
其中
px2a
px2b
img
中具有相同颜色值的独立像素

我一直在摆弄
numpy
指数化,但我没有什么比这更好的了。虽然上述方法部分有效,但由于在阵列上进行迭代,它的速度当然非常慢。如果我给它一个普通大小的图像,它将需要一个不可接受的时间来完成

有人能告诉我一个更快的解决方案吗?


编辑:我的理想未来算法摘要: 最后一个
m X n
数组将用于两件事:

  • 通过将数组存储在数据库中,以不同的时间间隔绘制数据
  • 对由于某些外部事件而导致的阵列变化进行统计分析
注意:基本图像已经是一种绘图,但我无法访问数值数据,因此我需要反转绘图


一些(可选)要求对我的代码的目标进行精确说明: 查找表通常基于像素值,而不是位置。不过,在这种情况下,位置会更好(我认为)。上面代码中的
lut
数组是表示速度的线性色标的虚拟,我需要将其数字化,如下所示:

由于每个级别的刻度都是统一的颜色,因此实际上可以将刻度缩小为一组尺寸
1 X高度X 3
(适用于RGB、HSV等)。事实上,我可以去掉第一个维度,然后迭代一个
高x3
的数组。我用来数字化比例的东西有:

  • 它以零为中心
  • 它是对称的
  • 它是线性的
  • 已知最大值和最小值
因此,该比例的数字表示为:

digital_scale = np.linspace(-max_speed, max_speed, lut.shape[0])

我只需要获得我在
lut
中查找的像素值的
y
索引,然后是
digital\u scale
y
第个元素,以输出我需要的值(标量,请参见上面的算法摘要)。

您可以使用Kd tree,这里是一个演示:

import numpy as np
from scipy import spatial

H, W = 200, 100

np.random.seed(1)
a = np.random.randint(0, 20, (H, W, 3))
b = np.random.randint(0, 20, (20, 3))
tree = spatial.cKDTree(a.reshape(-1, 3))
res = tree.query_ball_point(b, 0.5, p=1)
print res
输出为:

[[] [577, 17471] [14636, 4515, 13693, 10988, 15013] [16935, 8576, 13286]
 [2443] [7743, 5914] [] [7469, 19736, 13395, 14992, 9083, 15514]
 [1167, 11416] [3903, 4968] [16504, 2996, 10805, 2264] [] [6725]
 [14437, 5888] [17667] [4681, 2545, 6442] [15067, 4533]
 [7876, 2235, 10152, 3288] [15404, 5691, 17216]
 [15586, 9916, 16938, 15931, 4828, 4069]]
要查看索引2的结果:

rows, cols = np.where(np.all(a == b[None, None, 2], axis=-1))
assert np.all(rows * W + cols == sorted(res[2]))

您想要的实际输出是什么(即,您肯定是在以某种数组格式查找数据,而不是将其打印出来)?你到底想干什么?LUT通常是图像值的函数,而不是索引。是否需要最终输出作为数据数组,它可以是其他的吗?如果LUT级别有多个像素,那么输出看起来如何?出于测试目的,您可能需要使用
np.random.seed(1);img=np.random.randint(0,9,(3,3,3));img[1,0,:]=img[0,0,:]
@wwii感谢您的关注。最好将输出作为数据数组,但我可以使用任何可以索引的东西,例如列表。这样做的原因是我想使用索引获取
digital\u scale
的相应元素,这将是我的最终结果(请参见上面的编辑)。是的,任何LUT级别都可能有超过1个像素(这是我的代码失败的原因,请参阅上面提到的错误,我希望输出为此类情况下的嵌套列表或数组。这似乎非常复杂。我理解结果,但不理解内部内容,即使在阅读文档时也是如此。我将尽快尝试(仍然是本周)!感谢您的回答:-)
[[] [577, 17471] [14636, 4515, 13693, 10988, 15013] [16935, 8576, 13286]
 [2443] [7743, 5914] [] [7469, 19736, 13395, 14992, 9083, 15514]
 [1167, 11416] [3903, 4968] [16504, 2996, 10805, 2264] [] [6725]
 [14437, 5888] [17667] [4681, 2545, 6442] [15067, 4533]
 [7876, 2235, 10152, 3288] [15404, 5691, 17216]
 [15586, 9916, 16938, 15931, 4828, 4069]]
rows, cols = np.where(np.all(a == b[None, None, 2], axis=-1))
assert np.all(rows * W + cols == sorted(res[2]))