KeyError when plotting a ROC curve in Python
I am trying to plot ROC curves for a multi-label classification problem. Before binarizing the labels, my target column looked like this:
TargetGrouped
I5
I2
R0
I3
Here is the part of the code I use to plot the ROC curves:
from sklearn.metrics import roc_curve, auc
import matplotlib.pyplot as plt

def computeROC(self, n_classes, y_test, y_score):
    # Compute ROC curve and ROC area for each class
    fpr = dict()
    tpr = dict()
    roc_auc = dict()
    # Compute micro-average ROC curve and ROC area
    fpr["micro"], tpr["micro"], _ = roc_curve(y_test.ravel(), y_score.ravel())
    roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])
    # Plot ROC curve
    plt.figure()
    plt.plot(fpr["micro"], tpr["micro"],
             label='micro-average ROC curve (area = {0:0.2f})'
                   ''.format(roc_auc["micro"]))
    for i in range(n_classes):
        plt.plot(fpr[i], tpr[i], label='ROC curve of class {0} (area = {1:0.2f})'
                                       ''.format(i, roc_auc[i]))
    plt.plot([0, 1], [0, 1], 'k--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Some extension of Receiver operating characteristic to multi-class')
    plt.legend(loc="lower right")
    plt.show()
import pandas as pd
from sklearn import preprocessing
from sklearn.ensemble import RandomForestClassifier

lb = preprocessing.LabelBinarizer()
infoDF = infoDF.join(pd.DataFrame(lb.fit_transform(infoDF["TargetGrouped"]),
                                  columns=lb.classes_, index=infoDF.index))
# Extracted features and split infoDF dataframe => got X_train, X_test, y_train, y_test
rfc = RandomForestClassifier()
rfc.fit(X_train, y_train)
y_pred = rfc.predict(X_test)
classes = y_train.shape[1]
computeROC(classes, y_test, y_pred)
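To see what the binarized target passed in as y_test looks like, here is a small sketch that runs LabelBinarizer on only the four example values shown above (the full column may contain more classes). Each class becomes one 0/1 column, in the sorted order of lb.classes_:

```python
from sklearn.preprocessing import LabelBinarizer

# The four example values from the TargetGrouped column above
labels = ["I5", "I2", "R0", "I3"]

lb = LabelBinarizer()
Y = lb.fit_transform(labels)

# classes_ is sorted, so column j of Y corresponds to lb.classes_[j]
print(lb.classes_)  # ['I2' 'I3' 'I5' 'R0']
# One row per sample, one indicator column per class
print(Y)
print(Y.shape)  # (4, 4)
```

Because every class gets its own column, a per-class ROC curve can be computed from column i of this matrix, and class index i in the plotting loop maps back to the label lb.classes_[i].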
When I run it, I get the following error:
Traceback (most recent call last):
  File "<ipython-input-50-15a83ece5e44>", line 3, in <module>
    evaluation.computeROC(classes, y_test, y_pred)
  File "<ipython-input-49-526a19a07850>", line 18, in computeROC
    plt.plot(fpr[i], tpr[i], label='ROC curve of class {0} (area = {1:0.2f})'
KeyError: 0
The problem is actually in these lines:
for i in range(n_classes):
    plt.plot(fpr[i], tpr[i], label='ROC curve of class {0} (area = {1:0.2f})'
                                   ''.format(i, roc_auc[i]))
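fpr and tpr are dictionaries, and the only key you ever initialize in them is 'micro' (the same is true of roc_auc). The for loop assigns i the integer values from 0 to n_classes - 1, but you never define what fpr[0] and tpr[0] are, so the very first lookup raises KeyError: 0. (I suspect you think they are lists, but that is only a guess.) What does n_classes contain? It is an int, with the value 32.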
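A minimal reproduction of the failure, using a plain dict that, like fpr in computeROC, only ever receives the "micro" key (the stored values are made up for illustration):

```python
# Like fpr in computeROC, this dict only ever gets the "micro" key
fpr = dict()
fpr["micro"] = [0.0, 0.5, 1.0]

caught = None
try:
    fpr[0]  # the same lookup the plotting loop performs when i == 0
except KeyError as exc:
    caught = exc.args[0]  # the missing key

print("KeyError:", caught)  # KeyError: 0
```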
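One way to fix it is to fill fpr[i], tpr[i] and roc_auc[i] for every class before the plotting loop runs. The sketch below follows the per-class pattern from the scikit-learn multi-class ROC example; it assumes y_test is the binarized indicator matrix (one column per class) and y_score has one score column per class. The function name compute_roc_dicts and the toy data are hypothetical:

```python
import numpy as np
from sklearn.metrics import auc, roc_curve


def compute_roc_dicts(n_classes, y_test, y_score):
    """Hypothetical helper: fill fpr/tpr/roc_auc for each class index plus "micro"."""
    # Accept DataFrames too (y_test comes from a pandas split) by converting to arrays
    y_test = np.asarray(y_test)
    y_score = np.asarray(y_score)
    fpr, tpr, roc_auc = dict(), dict(), dict()

    # Per-class curves: column i of the targets against column i of the scores
    for i in range(n_classes):
        fpr[i], tpr[i], _ = roc_curve(y_test[:, i], y_score[:, i])
        roc_auc[i] = auc(fpr[i], tpr[i])

    # Micro-average: pool every (sample, class) decision into one binary problem
    fpr["micro"], tpr["micro"], _ = roc_curve(y_test.ravel(), y_score.ravel())
    roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])
    return fpr, tpr, roc_auc


# Toy example with 2 classes and made-up scores
y_true = np.array([[1, 0], [0, 1], [1, 0], [0, 1]])
scores = np.array([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4], [0.3, 0.7]])
fpr, tpr, roc_auc = compute_roc_dicts(2, y_true, scores)
print(roc_auc)
```

With these keys in place, the original plotting loop can stay as it is. Note that rfc.predict returns hard 0/1 predictions, so each per-class curve will have only one point between (0, 0) and (1, 1); continuous per-class scores produce smoother curves.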