如何将tensorflow张量与Python对象进行比较?

如何将tensorflow张量与Python对象进行比较?,python,tensorflow,Python,Tensorflow,我正在尝试将从tfrec文件加载的字符串标签张量转换为数字,以便进行一次热编码。其思想是使用numpy数组作为查找表,一旦命中,索引将返回并存储在张量中 然而,问题是张量不能直接与python对象进行比较。我尝试使用tf.map_fn枚举我的标签批次,并使用tf.cond进行比较,但没有成功: def elem_op(t): global all_labels for idx, lbl in enumerate(all_labels): lbl_tensor =

我正在尝试将从tfrec文件加载的字符串标签张量转换为数字,以便进行一次热编码。其思想是使用numpy数组作为查找表,一旦命中,索引将返回并存储在张量中

然而,问题是张量不能直接与python对象进行比较。我尝试使用tf.map_fn枚举我的标签批次,并使用tf.cond进行比较,但没有成功:

def elem_op(t):
    global all_labels
    for idx, lbl in enumerate(all_labels):
        lbl_tensor = tf.constant(lbl.encode())  # tensorflow stores string as bytes, so convert the Python string object to bytes tensor
        ret = tf.cond(tf.equal(lbl_tensor, t), lambda : idx, lambda : -1)
        if ret != -1:  # now this doesn't work because tf.cond returns a tensor
            return ret
    return -1

# labels is a tensor storing a batch of label strings
train_labels = tf.map_fn(fn=elem_op, elems=labels, dtype=tf.int32)
问题是tf.cond还返回一个张量,不能在“if”子句中使用。我想知道解决这个问题的方法是什么


谢谢

您必须在会话中计算张量,以获得其实际值

将条件更改为:

if sess.run(ret) != -1:
其中
sess
是您的
tf.Session
实例。例如:

sess = tf.Session()
同样,您可以运行:

sess.run(train_labels)

忘了提到这个例程将在“model_fn”中调用,因此没有显式的“session”对象,特别是在elem_op子例程中,并且我不能使用sess.run(ret)或ret.eval(session=sess)。忘了提到model_fn中没有显式的session对象,所以很遗憾这不起作用。我担心“如何将tensorflow张量与Python对象进行比较[没有会话]?”是“你没有”。这是一种方法——在将所有字符串标签传递给tensorflow之前将其预处理为整数,从而避免了tensor v.s.Python obj比较。但我想知道是否还有其他方法可以做到这一点。。?