Warning: file_get_contents(/data/phpspider/zhask/data//catemap/2/python/343.json): failed to open stream: No such file or directory in /data/phpspider/zhask/libs/function.php on line 167

Warning: Invalid argument supplied for foreach() in /data/phpspider/zhask/libs/tag.function.php on line 1116

Notice: Undefined index: in /data/phpspider/zhask/libs/function.php on line 180

Warning: array_chunk() expects parameter 1 to be array, null given in /data/phpspider/zhask/libs/function.php on line 181
Python sklearn的GMM性能异常差_Python_Scikit Learn_Classification_Cluster Analysis - Fatal编程技术网

Python sklearn的GMM性能异常差

Python sklearn的GMM性能异常差,python,scikit-learn,classification,cluster-analysis,Python,Scikit Learn,Classification,Cluster Analysis,我试图使用scikitlearn的DPGMM分类器对一些模拟数据进行建模,但性能不佳。下面是我正在使用的示例: from sklearn import mixture import numpy as np import matplotlib.pyplot as plt clf = mixture.DPGMM(n_components=5, init_params='wc') s = 0.1 a = np.random.normal(loc=1, scale=s, size=(1000,)) b

我试图使用scikitlearn的DPGMM分类器对一些模拟数据进行建模,但性能不佳。下面是我正在使用的示例:

from sklearn import mixture
import numpy as np
import matplotlib.pyplot as plt
clf = mixture.DPGMM(n_components=5, init_params='wc')
s = 0.1
a = np.random.normal(loc=1, scale=s, size=(1000,))
b = np.random.normal(loc=2, scale=s, size=(1000,))
c = np.random.normal(loc=3, scale=s, size=(1000,))
d = np.random.normal(loc=4, scale=s, size=(1000,))
e = np.random.normal(loc=7, scale=s*2, size=(5000,))
noise = np.random.random(500)*8 
data = np.hstack([a,b,c,d,e,noise]).reshape((-1,1))
clf.means_ = np.array([1,2,3,4,7]).reshape((-1,1))
clf.fit(data)
labels = clf.predict(data)
plt.scatter(data.T, np.random.random(len(data)), c=labels, lw=0, alpha=0.2)
plt.show()

我认为这正是高斯混合模型所能解决的问题。我尝试过使用alpha,使用gmm而不是dpgmm,更改起始组件的数量等等。我似乎无法获得可靠和准确的分类。有什么我只是错过了?还有其他更合适的模型吗?

因为您没有足够长的迭代时间使其收敛

检查

clf.converged_
并尝试将
n_iter
增加到
1000

但是,请注意,
DPGMM
在这个数据集上仍然失败得很惨,最终将集群数量减少到只有2个