Python MxNet metrics API用于计算带有向量标签的多类逻辑回归的精度
如何使用MxNet计算带有向量标签的多类逻辑回归分类器的精度? 以下是标签的示例:Python MxNet metrics API用于计算带有向量标签的多类逻辑回归的精度,python,logistic-regression,mxnet,multiclass-classification,Python,Logistic Regression,Mxnet,Multiclass Classification,如何使用MxNet计算带有向量标签的多类逻辑回归分类器的精度? 以下是标签的示例: Class1: [1,0,0,0] Class2: [0,1,0,0] Class3: [0,0,1,0] Class4: [0,0,0,1] 使用此函数的简单方法会产生错误的结果,因为argmax会将模型输出压缩为具有最大概率值的索引 def evaluate_accuracy(data_iterator, ctx, net): acc = mx.metric.Accuracy() for i
Class1: [1,0,0,0]
Class2: [0,1,0,0]
Class3: [0,0,1,0]
Class4: [0,0,0,1]
使用此函数的简单方法会产生错误的结果,因为argmax会将模型输出压缩为具有最大概率值的索引
def evaluate_accuracy(data_iterator, ctx, net):
acc = mx.metric.Accuracy()
for i, (data, label) in enumerate(data_iterator):
data = data.as_in_context(ctx)
label = label.as_in_context(ctx)
out = net(data)
p = nd.argmax(out, axis=1)
acc.update(preds=p, labels=label)
return acc.get()[1]
我目前的解决方案有点老套:
def evaluate_accuracy(data_iterator, ctx, net):
acc = mx.metric.Accuracy()
for i, (data, label) in enumerate(data_iterator):
data = data.as_in_context(ctx)
label = label.as_in_context(ctx)
out = net(data)
p = nd.argmax(out, axis=1)
l = nd.argmax(label, axis=1)
acc.update(preds=p, labels=l)
return acc.get()[1]
度量是棘手的。它并不能将一个热编码标签作为基本事实
我发现这有点违反直觉,但您需要传递非一个热编码标签作为基本事实,但实际的类(例如,2而不是[0,0,1,0])。否则,精确性将无法以您期望的方式工作。看看我先前的答覆:
此外,MxNet希望类以0开头。所以,如果你有从1开始的类,那么你需要通过减去1来调整所有类。谢谢Sergei。我想这应该是一个错误,因为SoftmaxCrossEntropyLoss允许标签具有概率分布,但metrics API没有考虑到这一点。@Sergey,顺便说一句,你也可以回答这个问题。看起来基于0的标签是原因之一