Warning: file_get_contents(/data/phpspider/zhask/data//catemap/2/tensorflow/5.json): failed to open stream: No such file or directory in /data/phpspider/zhask/libs/function.php on line 167

Warning: Invalid argument supplied for foreach() in /data/phpspider/zhask/libs/tag.function.php on line 1116

Notice: Undefined index: in /data/phpspider/zhask/libs/function.php on line 180

Warning: array_chunk() expects parameter 1 to be array, null given in /data/phpspider/zhask/libs/function.php on line 181
如何在TensorFlow中为2个单独的输入共享LSTM单元?_Tensorflow_Neural Network - Fatal编程技术网

如何在TensorFlow中为2个单独的输入共享LSTM单元?

如何在TensorFlow中为2个单独的输入共享LSTM单元?,tensorflow,neural-network,Tensorflow,Neural Network,假设我有两个输入q和a,如何使这两个输入共享一个LSTM单元?下面是我的部分代码 def lstmnets(self, sequence, seq_len): seq_embeds = self.embeds(sequence) # lstm_cell = tf.contrib.rnn.BasicLSTMCell(self.hidden_size) lstm_cell = tf.nn.rnn_cell.LSTMCell(self.hidden_size) ini

假设我有两个输入
q
a
,如何使这两个输入共享一个
LSTM
单元?下面是我的部分代码

def lstmnets(self, sequence, seq_len):
    seq_embeds = self.embeds(sequence)

    # lstm_cell = tf.contrib.rnn.BasicLSTMCell(self.hidden_size)
    lstm_cell = tf.nn.rnn_cell.LSTMCell(self.hidden_size)
    init_state = lstm_cell.zero_state(self.batch_size, dtype=tf.float32)
    lstm_out, final_state = tf.nn.dynamic_rnn(lstm_cell, seq_embeds, initial_state=init_state, sequence_length=seq_len)
    return lstm_out

def inference(self, q, a, q_len, a_len):
    with tf.variable_scope('lstmnets') as scope:
        query_rep = self.lstmnets(q, q_len)
        scope.reuse_variables()
        title_rep = self.lstmnets(a, a_len)
但是对于这个代码,我的结构有两个堆叠的
LSTM
,如下图所示。我怎么能只使用一个
LSTM
?此外,如何初始化LSTM权重并将其添加到直方图中?到目前为止,我还没有找到这方面的相关教程。谢谢


您的代码似乎很好,因为它使用scope.reuse_variable()来共享LSTM权重。最好的检查方法是打印图形中的变量,并验证lstm_单元格是否只声明一次。因此,在推理函数中打印变量名称:

def inference(self, q, a, q_len, a_len):
  with tf.variable_scope('lstmnets') as scope:
    query_rep = self.lstmnets(q, q_len)
    scope.reuse_variables()
    title_rep = self.lstmnets(a, a_len)
  for v in tf.global_variables():
    print(v.name)

谢谢我会尝试一下。你知道我在哪里可以初始化LSTM单元格权重吗?我打印名称,似乎只有一个LSTM,因为参数是
lstmnets/embeddings:0 lstmnets/rnn/lstmu cell/kernel:0 lstmnets/rnn/lstmu cell/bias:0
,为什么图形有两个堆叠的rnn…或者堆叠的LSTM完全相同,并且它们之间的线表示
共享
?与CNN共享相比,这有点奇怪……是的,图中显示的堆叠LSTM应该是相同的。在那里也查一下他们的名字。我只是深入研究了图表。该行表示两个张量:
内核
偏差
。我认为它们是共享的
W
U
b