Tensorflow:可以阻止tf.where的一个分支执行吗?

Tensorflow:可以阻止tf.where的一个分支执行吗?,tensorflow,Tensorflow,我正在进行编码器-解码器设置。我希望能够运行编码器一次,然后执行多个解码器运行。我提出的解决方案是向解码器提供一个TF条件节点(使用TF.where),该节点包含编码器的最终隐藏状态(在这种情况下,当我请求解码器输出时,TF将运行编码器),或者一个带有编码器存储结果的占位符(在这种情况下,理论上TF不需要运行编码器) 以下是守则的相关部分: encoder_state = tf.where(gen_math_ops.greater_equal(branching_points, 0), enco

我正在进行编码器-解码器设置。我希望能够运行编码器一次,然后执行多个解码器运行。我提出的解决方案是向解码器提供一个TF条件节点(使用TF.where),该节点包含编码器的最终隐藏状态(在这种情况下,当我请求解码器输出时,TF将运行编码器),或者一个带有编码器存储结果的占位符(在这种情况下,理论上TF不需要运行编码器)

以下是守则的相关部分:

encoder_state = tf.where(gen_math_ops.greater_equal(branching_points, 0), encoder_state,
                         rnn.static_rnn(encoder_cell, encoder_inputs, dtype=dtype)[1])
由于我没有从这个方法中得到加速,我很确定它不起作用,tf.where的两个分支每次都由tf运行,即使它只需要从占位符读取

有没有办法使用tf.where,使其不运行编码器?我已经看过该方法的描述,我不确定是否总是计算这两个分支,我在这个问题上看到了相互矛盾的信息

谢谢

当您想推迟执行其中一个分支,直到对谓词进行求值时,可以使用该函数

encoder_state = tf.cond(
    tf.greater_equal(branching_points, 0),
    lambda: encoder_state,
    lambda: tf.nn.static_rnn(encoder_cell, encoder_inputs, dtype=dtype)[1])

我尝试使用tf.cond创建一个模型并输入一个字典,但是tf.cond只接受一个输入,因此如果有多个分支点,这将不起作用。
我已经创建了解决方案,但它非常复杂,我希望看到更好的解决方案,特别是如果真的和假的计算代价很高,那么这个只会提高性能。 如果未选择true\u fn或false\u fn的分支(例如,如果在这些函数中使用tf.assign),则此解决方案也很有用

首先,我创建布尔张量:

branch_1 = tf.greater_equal(branching_points, 0)
branch_2 = tf.logical_not(branch_1)
然后我使用一个布尔掩码只执行来自分支的真实条件

result_1 = tf.boolean_mask(branch_1)
result_2 = tf.boolean_mask(branch_2)
最后,如果需要,可以形成一个张量。 如果顺序很重要,您可以使用
tf.where(tf.equal(branch_1,True))
tf.where(tf.equal(branch_2,True))
分别获取结果_1和结果_2的索引。然后你申请。
如果顺序不重要,您可以简单地使用

tf.cond仅对标量有效。所以,如果分支点是布尔值的张量,这就行不通了,在这种情况下,它看起来是