Python 利用K-最近邻聚类的有效方法

Python 利用K-最近邻聚类的有效方法,python,opencv,numpy,machine-learning,Python,Opencv,Numpy,Machine Learning,我试图将图像上的颜色聚类到预定义的类(黑、白、蓝、绿、红)。我正在使用以下代码: import numpy as np import cv2 src = cv2.imread('objects.png') colors = np.array([[0x00, 0x00, 0x00], [0xff, 0xff, 0xff], [0xff, 0x00, 0x00], [0x00, 0x

我试图将图像上的颜色聚类到预定义的类(黑、白、蓝、绿、红)。我正在使用以下代码:

import numpy as np
import cv2

src = cv2.imread('objects.png')

colors = np.array([[0x00, 0x00, 0x00],
                   [0xff, 0xff, 0xff],
                   [0xff, 0x00, 0x00],
                   [0x00, 0xff, 0x00],
                   [0x00, 0x00, 0xff]], dtype=np.float32)
classes = np.array([[0], [1], [2], [3], [4]], np.float32)
dst = np.zeros(src.shape, np.float32)

knn = cv2.KNearest()
knn.train(colors, classes)

# This loop is very inefficient!
for i in range(0, src.shape[0]):
    for j in range(0, src.shape[1]):
        sample = np.reshape(src[i,j], (-1,3)).astype(np.float32)
        retval, result, neighbors, dist = knn.find_nearest(sample, 1)
        dst[i,j] = colors[result[0,0]]

cv2.imshow('src', src)
cv2.imshow('dst', dst)
cv2.waitKey()
代码运行良好,结果如下所示。左边的图像是输入,右边的图像是输出


然而,上面的循环效率很低,并且使转换速度很慢。替换上述循环的最有效的Numpy操作是什么?

如果您想要一个简单的平方差度量(“这是欧几里德最接近的数字”),这将起作用

计算差异

diff = ((src[:,:,:,None] - colors.T)**2).sum(axis=2)
(假设
src
的形状为y,x,3)

选择最接近的颜色索引:

index = diff.argmin(axis=2)
新图像:

out = colors[index]
如果颜色的分量值真的为0或0xff,可以使用

out = np.where(src>0x88, 0xff, 0)

你可以建立一个查找表。这样你就可以知道每个颜色的对应类。它不一定是256x256x256。你可以减少一些容器。

我设法用下面的代码删除循环。代码运行非常快,几乎与C++版本类似。

import numpy as np
import cv2

src = cv2.imread('objects.png')
src_flatten = np.reshape(np.ravel(src, 'C'), (-1, 3))
dst = np.zeros(src.shape, np.float32)

colors = np.array([[0x00, 0x00, 0x00],
                   [0xff, 0xff, 0xff],
                   [0xff, 0x00, 0x00],
                   [0x00, 0xff, 0x00],
                   [0x00, 0x00, 0xff]], dtype=np.float32)
classes = np.array([[0], [1], [2], [3], [4]], np.float32)

knn = cv2.KNearest()
knn.train(colors, classes)
retval, result, neighbors, dist = knn.find_nearest(src_flatten.astype(np.float32), 1)

dst = colors[np.ravel(result, 'C').astype(np.uint8)]
dst = dst.reshape(src.shape).astype(np.uint8)

cv2.imshow('src', src)
cv2.imshow('dst', dst)
cv2.waitKey()
代码像以前一样生成正确的结果,执行时间更快


后一个建议还将为您提供0xff和0的所有8种颜色组合,例如(0xff,0xff,0),而不仅仅是您指定的5种颜色组合。