Python 使用tf.cond()时,Tensorflow要求输入不必要的占位符
考虑以下代码片段,其中包括tensorflowPython 使用tf.cond()时,Tensorflow要求输入不必要的占位符,python,tensorflow,machine-learning,Python,Tensorflow,Machine Learning,考虑以下代码片段,其中包括tensorflowtf.cond() 在这两种情况下,bb为False,zz的计算理论上不依赖于xx,但tensorflow仍然需要输入xx。尽管它可以作为虚拟数组提供,但它必须与yy的形状相匹配,并且不像dict2那样干净 有谁能建议如何在不为xx提供值的情况下计算zz(使用tf.cond()或任何其他方法)?您可以将xx定义为tf.Variable,并给它一个默认值(只要xx没有输入另一个值,就会使用该值)。需要注意的几件事: 虽然xx不是占位符,但您仍然可以通过
tf.cond()
在这两种情况下,bb
为False
,zz
的计算理论上不依赖于xx
,但tensorflow仍然需要输入xx
。尽管它可以作为虚拟数组提供,但它必须与yy
的形状相匹配,并且不像dict2
那样干净
有谁能建议如何在不为
xx
提供值的情况下计算zz
(使用tf.cond()
或任何其他方法)?您可以将xx
定义为tf.Variable
,并给它一个默认值(只要xx
没有输入另一个值,就会使用该值)。需要注意的几件事:
xx
不是占位符,但您仍然可以通过feed\u dict
将值输入到占位符中,从而将其视为占位符validate\u shape=False
以便可以将任何形状输入xx
trainable=False
,使xx
未被优化(否则,优化器可能会将其默认值更改为Nan
,这可能会导致问题)xx
的值,例如使用tf.global\u variables\u initializer()
import tensorflow as tf
import numpy as np
bb = tf.placeholder(tf.bool)
xx = tf.Variable(initial_value=0.0,validate_shape=False,trainable=False,name='xx')
yy = tf.placeholder(tf.float32, name='yy')
zz = tf.cond(bb, lambda: xx + yy, lambda: 100 + yy)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
dict1 = {bb:False, yy:np.array([1., 3, 4]), xx:np.array([5., 6, 7])}
print(sess.run(zz, feed_dict=dict1))
dict2 = {bb:False, yy:np.array([1., 3, 4])}
print(sess.run(zz, feed_dict=dict2))
import tensorflow as tf
import numpy as np
bb = tf.placeholder(tf.bool)
xx = tf.Variable(initial_value=0.0,validate_shape=False,trainable=False,name='xx')
yy = tf.placeholder(tf.float32, name='yy')
zz = tf.cond(bb, lambda: xx + yy, lambda: 100 + yy)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
dict1 = {bb:False, yy:np.array([1., 3, 4]), xx:np.array([5., 6, 7])}
print(sess.run(zz, feed_dict=dict1))
dict2 = {bb:False, yy:np.array([1., 3, 4])}
print(sess.run(zz, feed_dict=dict2))