Python 如何在TensorFlow中将tf.equal()与shape=(1,1)的标签张量一起使用?

Python 如何在TensorFlow中将tf.equal()与shape=(1,1)的标签张量一起使用?,python,python-3.x,tensorflow,Python,Python 3.x,Tensorflow,我正在尝试评估单个图像。到目前为止,它是有效的。我得到每个类的个体概率和正确的标签。但是当我尝试使用tf.argmax(标签,1)获取类时,我总是得到类“0” 对于每个数据点,我总是得到“0”。我想继续使用tf.equal()但是标签的argmax值错误,当然这是不可能的。有什么想法吗 ... image, label = ... logits = model(image) arg_log = tf.argmax(logits, 1) arg_lbl = tf.argmax(label, 1)

我正在尝试评估单个图像。到目前为止,它是有效的。我得到每个类的个体概率和正确的标签。但是当我尝试使用
tf.argmax(标签,1)
获取类时,我总是得到类“0”

对于每个数据点,我总是得到“0”。我想继续使用
tf.equal()
但是标签的argmax值错误,当然这是不可能的。有什么想法吗

...
image, label = ...
logits = model(image)
arg_log = tf.argmax(logits, 1)
arg_lbl = tf.argmax(label, 1) # What must i change here?
cor_pre = tf.equal(arg_log, arg_lbl)
...
编辑 我每次都得到“0”,因为我得到了索引!我把问题改为: 如何在TensorFlow中对形状为(1,1)的张量使用tf.argmax()? 致: 如何在TensorFlow中将tf.equal()与形状=(1,1)的标签张量一起使用?

基于,
tf.argmax()
采用
输入和
轴以及其他参数

如果您的标签具有形状[1,1],您希望从轴1上的argmax得到什么?只有一个条目

很可能,您希望将标签与argmaxed结果进行比较。因此:

...
image, label = ...
# label: Tensor("..", shape=(1, 1), dtype=int32)
logits = model(image)
# logits: Tensor("..", shape=(1, 10), dtype=float32)
predic = tf.nn.softmax(logits)
arg_log = tf.argmax(logits, 1)
...
pre, lbl, a_log, a_lbl = sess.run([predic, label, arg_log, arg_lbl])
cor_pre = tf.equal(arg_log, tf.cast(label, tf.int64))

形状(1,1)的任何数组将只包含一个元素。一个元素必须是数组中的最大元素。

我找到了一个解决方案,在
tf.equal()中使用它之前,将标签强制转换为
int64


标签是如何定义的?我更改了问题:)tf.argmax()返回最大值的索引,因此对于形状张量(1,1),索引总是0:D“标签定义”是什么意思?好的。那么您是在寻找大小为1的数组中的最大值?我想你不应该感到惊讶,这是我问这个问题半分钟后意识到的。但是我如何使用tf呢。现在相等了?我明白了!我必须将它转换为int64!我应该删除整个问题吗?我已经更新了我的解决方案,将强制转换包含到int64中——希望现在有了更完整的解释。
...
image, label = ...
# label: Tensor("..", shape=(1, 1), dtype=int32)
logits = model(image)
# logits: Tensor("..", shape=(1, 10), dtype=float32)
predic = tf.nn.softmax(logits)
arg_log = tf.argmax(logits, 1)
...
pre, lbl, a_log, a_lbl = sess.run([predic, label, arg_log, arg_lbl])
cor_pre = tf.equal(arg_log, tf.cast(label, tf.int64))
...
image, label = ...
logits = model(image)
cor_pre = tf.equal(tf.argmax(logits, 1), tf.cast(label, tf.int64))
...