Python 如何提高Sklearn GMM predict()的性能速度?

Python 如何提高Sklearn GMM predict()的性能速度?,python,scikit-learn,multiprocessing,Python,Scikit Learn,Multiprocessing,我正在使用Sklearn对一些数据估计高斯混合模型(GMM) 估计之后,我有很多查询点。我想获得它们属于每个估计高斯分布的概率 下面的代码可以工作。但是,gmm_sk.predict_proba(query_points)部分非常慢,因为我需要在100000组样本上运行多次,其中每个样本包含1000个点 我想这是因为它是连续的。有没有办法使之平行?或者有没有其他方法让它更快?也许在GPU上使用TensorFlow 我看到TensorFlow有自己的GMM算法,但很难实现 以下是我编写的代码: i

我正在使用Sklearn对一些数据估计高斯混合模型(GMM)

估计之后,我有很多查询点。我想获得它们属于每个估计高斯分布的概率

下面的代码可以工作。但是,
gmm_sk.predict_proba(query_points)
部分非常慢,因为我需要在100000组样本上运行多次,其中每个样本包含1000个点

我想这是因为它是连续的。有没有办法使之平行?或者有没有其他方法让它更快?也许在GPU上使用TensorFlow

我看到TensorFlow有自己的GMM算法,但很难实现

以下是我编写的代码:

import numpy as np
from sklearn.mixture import GaussianMixture
import time


n_gaussians = 1000
covariance_type = 'diag'
points = np.array(np.random.rand(10000, 3), dtype=np.float32)
query_points = np.array(np.random.rand(1000, 3), dtype=np.float32)
start = time.time()

#GMM with sklearn
gmm_sk = GaussianMixture(n_components = n_gaussians, covariance_type=covariance_type)
gmm_sk.fit(points)
mid_t = time.time()
elapsed = time.time() - start
print("learning took "+ str(elapsed))

temp = []
for i in range(2000):
    temp.append(gmm_sk.predict_proba(query_points))

end_t = time.time() - mid_t
print("predictions took " + str(end_t))    
我解决了!使用
多处理
。 刚刚更换

temp = []
for i in range(2000):
    temp.append(gmm_sk.predict_proba(query_points))


如果使用“对角”或球形协方差矩阵而不是完全协方差矩阵进行拟合,则可以加快过程

使用:

协方差\u type='diag'

协方差

内部
GaussianMixture

另外,尝试减少高斯分量


但是,请记住,这可能会影响结果,但我看不到其他加速过程的方法。

我看到GMM中高斯分量的数量是1000,我认为这是一个非常大的数字,因为您的数据维度相对较低(3)。这可能是它运行缓慢的原因,因为它需要计算1000个独立的高斯。如果您的样本数很低,那么这也很容易过度拟合。您可以尝试使用较少数量的组件,这自然会更快,并且很可能会更好地推广。

注意,我已经使用了
'diag'
。它仍然很慢。也许并行化?@itzikBenShabat高斯混合函数没有n_jobs参数(用于计算的CPU数量),但如果您找到其他方法,post-it-cause也会很有趣:),考虑使用其他模块,例如TysFooFor,甚至其他软件来执行此任务。Matlab@itzikBenShabat你试过减少高斯分量吗?这是个限制。它必须保持1000。@itzikBenShabat哦,好的!跑步需要多长时间?我将在我的笔记本电脑上运行它,并让你知道谢谢你的建议,但高斯数是一个我无法改变的约束,那么也许你可以天真地对样本集进行并行化,例如,在不同的线程上运行每一组。当然,如果必须将结果写入共享容器,则需要小心。
import multiprocessing as mp
    query_points = query_points.tolist()
    parallel = mp.Pool()
    fv = parallel.map(par_gmm, query_points)
    parallel.close()
    parallel.join()