TensorFlow中的条件执行

TensorFlow中的条件执行,tensorflow,Tensorflow,如何根据条件选择执行图形的一部分 我的网络中有一部分只有在feed\u dict中提供占位符值时才能执行。如果未提供该值,则采用备用路径。我如何使用tensorflow来实现这一点 以下是我的代码的相关部分: sess.run(accuracy, feed_dict={inputs: mnist.test.images, outputs: mnist.test.labels}) N = tf.shape(outputs) cost = 0 if N > 0:

如何根据条件选择执行图形的一部分

我的网络中有一部分只有在
feed\u dict
中提供占位符值时才能执行。如果未提供该值,则采用备用路径。我如何使用tensorflow来实现这一点

以下是我的代码的相关部分:

sess.run(accuracy, feed_dict={inputs: mnist.test.images, outputs: mnist.test.labels})

N = tf.shape(outputs)
    cost = 0
    if N > 0:
        y_N = tf.slice(h_c, [0, 0], N)
        cross_entropy = tf.nn.softmax_cross_entropy_with_logits(y_N, outputs, name='xentropy')
        cost = tf.reduce_mean(cross_entropy, name='xentropy_mean')

在上面的代码中,我正在寻找一些东西来代替
if N>0:

这里有一个简单的例子,可以让您开始学习。它根据张量的形状执行图形的不同部分:

import tensorflow as tf

a = tf.Variable([[3.0, 3.0], [3.0, 3.0]])
b = tf.Variable([[1.0, 1.0], [2.0, 2.0]])
l = tf.shape(a)

add_op, sub_op = tf.add(a, b), tf.sub(a, b)

sess = tf.Session()
init = tf.initialize_all_variables()
sess.run(init)
t = sess.run(l)

print sess.run(sub_op if t[0] == 3 else add_op)

sess.close()

将3改为2,看看如何减去张量。如您所见,我启动了
add
sub
shape
的节点,然后在图中检查形状并运行特定部分。

Hrm。可能您需要的是tf.control\u flow\u ops.cond()

但这并没有导出到tf名称空间中,我在回答时没有检查这个接口的稳定性,但它在已发布的模型中使用,所以请使用它。:)

但是:由于在构建提要时实际上预先知道了所需的路径,因此也可以采用不同的方法通过模型调用单独的路径。执行此操作的标准方法是,例如,设置如下代码:

def model(input, n_greater_than):
  ... cleverness ...
  if n_greater_than:
     ... other cleverness...
  return tf.reduce_mean(input)


out1 = model(input, True)
out2 = model(input, False)
然后根据您将要运行计算时知道的情况,拉出out1或out2节点并设置提要。请记住,默认情况下,如果模型引用相同的变量(在model()funct之外创建它们),则基本上会有两条单独的路径通过

您可以在卷积mnist示例中看到此示例:

如果可以的话,我喜欢这样做,而不引入控制流依赖项