Warning: file_get_contents(/data/phpspider/zhask/data//catemap/2/python/279.json): failed to open stream: No such file or directory in /data/phpspider/zhask/libs/function.php on line 167

Warning: Invalid argument supplied for foreach() in /data/phpspider/zhask/libs/tag.function.php on line 1116

Notice: Undefined index: in /data/phpspider/zhask/libs/function.php on line 180

Warning: array_chunk() expects parameter 1 to be array, null given in /data/phpspider/zhask/libs/function.php on line 181
Python 在Tensorflow中使用惰性条件句_Python_Python 3.x_Tensorflow - Fatal编程技术网

Python 在Tensorflow中使用惰性条件句

Python 在Tensorflow中使用惰性条件句,python,python-3.x,tensorflow,Python,Python 3.x,Tensorflow,如果您对某些昂贵的操作有条件,您可能需要惰性行为,即仅计算所选的分支 以下选项可以工作,并且是懒惰的: >>> a. tf.zeros(0) >>> tf.cond(tf.equal(tf.size(a), tf.constant(0)), lambda: tf.constant(-1, dtype=tf.int64), lambda: tf.argmax(a)).eval() -1 您可以看到它是懒惰的,因为argmax没有被计算,因为它会导致错误。因为a

如果您对某些昂贵的操作有条件,您可能需要惰性行为,即仅计算所选的分支

以下选项可以工作,并且是懒惰的:

>>> a. tf.zeros(0)
>>> tf.cond(tf.equal(tf.size(a), tf.constant(0)), lambda: tf.constant(-1, dtype=tf.int64), lambda: tf.argmax(a)).eval()
-1
您可以看到它是懒惰的,因为argmax没有被计算,因为它会导致错误。因为argmax上的张量是空的。如果将argmax移出lambda,则会产生以下错误:

>>> am = tf.argmax(a)
>>> tf.cond(tf.equal(tf.size(a), tf.constant(0)), lambda: tf.constant(-1, dtype=tf.int64), lambda: tf.add(am, 1)).eval()
... Reduction axis 0 is empty in shape [0]
这不是由
tf.add
操作引起的。将其内联移动并再次工作:

>>> tf.cond(tf.equal(tf.size(a), tf.constant(0)), lambda: tf.constant(-1, dtype=tf.int64), lambda: tf.add(tf.argmax(a), 1)).eval()
-1

那么,问题是如何以更干净的方式处理惰性条件?

当条件函数变长时,上述方法会变得有点混乱。您可以做的是在条件外定义lambda表达式请注意,以下内容在Python interactive REPL中不起作用,它会导致
值错误:操作“cond_14/Merge”已标记为不可获取。

当您将代码放入python文件并以正常方式运行时,它确实起作用

import tensorflow as tf

sess = tf.InteractiveSession()

a = tf.zeros(0)
fn = lambda: tf.argmax(a)

res = tf.cond(
    tf.equal(tf.size(a), tf.constant(0)),
    lambda: tf.constant(-1, dtype=tf.int64),
    fn
    ).eval()
print(res)

res2 = tf.cond(
    tf.equal(tf.size(a), tf.constant(0)),
    lambda: tf.constant(-1, dtype=tf.int64),
    lambda: tf.add(fn(), tf.constant(1, dtype=tf.int64))
    ).eval()
print(res2)
# Output:
# -1
# -1