Reusing a MultiRNNCell for two different inputs in TensorFlow


I want a multi-layer LSTM model that, in each mini-batch, computes outputs for two different inputs, because the two outputs will be used differently later on.

I tried to achieve this as follows:

with tf.name_scope('placeholders'):
    X = tf.placeholder(tf.float64, shape=[batch_size, max_length, dim])
    Y = tf.placeholder(tf.float64, shape=[batch_size, max_length, dim])
    seq_length1 = tf.placeholder(tf.int32, [batch_size], name="len1")
    seq_length2 = tf.placeholder(tf.int32, [batch_size], name="len2")

with tf.variable_scope("model") as scope:
    layers = [
        tf.contrib.rnn.BasicLSTMCell(num_units=num, activation=tf.nn.relu, name="e_lstm")
        for num in neurons
    ]
    if training:    # apply dropout during training
        layers_e = [
            tf.contrib.rnn.DropoutWrapper(layer, input_keep_prob=keep_prob)
            for layer in layers
        ]
    multi_layer_cell = tf.contrib.rnn.MultiRNNCell(layers)
    _, states_s = tf.nn.dynamic_rnn(multi_layer_cell, X, dtype=tf.float64, sequence_length=seq_length1)  

    _, states_o = tf.nn.dynamic_rnn(multi_layer_cell, Y, dtype=tf.float64, sequence_length=seq_length2)
But in the graph visualized in TensorBoard, it actually builds two different RNNs inside the model scope, and the output of one RNN becomes the input of the other (and vice versa), which is not the desired behavior.

Can someone tell me how I should modify the code to get the desired behavior?

Thanks.

Add two lines:

with tf.variable_scope('rnn'):
    _, states_s = tf.nn.dynamic_rnn(multi_layer_cell, X, dtype=tf.float64, sequence_length=seq_length1)  
with tf.variable_scope('rnn', reuse=True):
    _, states_o = tf.nn.dynamic_rnn(multi_layer_cell, Y, dtype=tf.float64, sequence_length=seq_length2)
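As a sanity check (not part of the original answer), you can list the trainable variables after building the graph; if the reuse works, each LSTM layer should contribute exactly one kernel/bias pair rather than a duplicated set under a second scope:

for v in tf.trainable_variables():
    print(v.name, v.shape)
# Expected: one kernel and one bias per LSTM layer (all under .../rnn/...),
# not a second copy under a scope such as .../rnn_1/...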
I think the code below is a better approach, but I am not sure; suggestions are welcome.

with tf.variable_scope('rnn', reuse=tf.AUTO_REUSE):
    _, states_s = tf.nn.dynamic_rnn(multi_layer_cell, X, dtype=tf.float64, sequence_length=seq_length1)  
    _, states_o = tf.nn.dynamic_rnn(multi_layer_cell, Y, dtype=tf.float64, sequence_length=seq_length2)
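For context (my summary, not stated in the original answer): reuse=True requires the variables to already exist in the scope and raises an error otherwise, whereas tf.AUTO_REUSE creates them on the first call and reuses them on every later call. A minimal sketch of these semantics using tf.get_variable, with an illustrative scope name 'demo':

import tensorflow as tf

with tf.variable_scope('demo', reuse=tf.AUTO_REUSE):
    a = tf.get_variable('w', shape=[2, 2])  # first use: the variable is created
    b = tf.get_variable('w', shape=[2, 2])  # later uses: the existing variable is returned

print(a is b)  # True: both names refer to the same underlying variable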