Python 使用round或argmax评估Tensorflow中的softmax(分类)精度? 问题

Python 使用round或argmax评估Tensorflow中的softmax(分类)精度? 问题,python,tensorflow,prediction,softmax,Python,Tensorflow,Prediction,Softmax,我尝试使用Tensorflow制作softmax分类器,并使用tf.argmax()进行预测。我发现其中一个y总是高于0.5,我使用了tf.round()而不是tf.argmax() 然而,这两种方法之间的准确度差距约为20%,tf.round()的准确度高于tf.argmax() 我希望这两种方法的精度应该完全相同,或者tf.round()应该低于tf.argmax(),但事实并非如此。有人知道为什么会这样吗 Y和Y_ 正确预测 精确 这个问题有点老了,但有人可能会像我一样绊倒在它上面 简言之

我尝试使用Tensorflow制作softmax分类器,并使用
tf.argmax()
进行预测。我发现其中一个
y
总是高于0.5,我使用了
tf.round()
而不是
tf.argmax()

然而,这两种方法之间的准确度差距约为20%,
tf.round()
的准确度高于
tf.argmax()

我希望这两种方法的精度应该完全相同,或者
tf.round()
应该低于
tf.argmax()
,但事实并非如此。有人知道为什么会这样吗

Y和Y_
正确预测 精确
这个问题有点老了,但有人可能会像我一样绊倒在它上面

简言之:

  • tf.argmax
    减少axis参数指定的一个维度,它将索引返回到沿该轴的最大值
  • tf.round
    对每个值进行舍入,返回与输入相同的形状
  • 使用
    tf.round
    的“精度”将随着轴指定尺寸的长度而增加
e、 g

输出:

argmax shape: (3,)
round shape: (3, 3)
argmax value [0 1 1]
round value [[1. 0. 0.]
 [0. 1. 0.]
 [0. 0. 0.]]
argmax accuracy: 0.6666667
round accuracy: 0.8888889
您可以看到,使用
tf.round
通常会为每行上的两个额外零获得分数

correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))

correct_prediction = tf.equal(tf.round(y), tf.round(y_))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
y = [
    [1.0, 0.0, 0.0],
    [0.0, 1.0, 0.0],
    [0.0, 0.0, 1.0]
]

y_ = [
    [0.92, 0.4, 0.4],
    [0.2,0.6,0.2],
    [0.2,0.4,0.2]
]

print('argmax shape:', tf.argmax(y_, 1).shape)
print('round shape:', tf.round(y_).shape)

with tf.Session():
    print('argmax value', tf.argmax(y_, 1).eval())
    print('round value', tf.round(y_).eval())

    correct_prediction_argmax = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
    correct_prediction_round = tf.equal(tf.round(tf.cast(y, tf.float32)), tf.round(y_))
    print('argmax accuracy:', tf.reduce_mean(tf.cast(correct_prediction_argmax, tf.float32)).eval())
    print('round accuracy:', tf.reduce_mean(tf.cast(correct_prediction_round, tf.float32)).eval())
argmax shape: (3,)
round shape: (3, 3)
argmax value [0 1 1]
round value [[1. 0. 0.]
 [0. 1. 0.]
 [0. 0. 0.]]
argmax accuracy: 0.6666667
round accuracy: 0.8888889