Python Tensorflow: Identifying the final state of a MultiRNNCell


I am new to TF and I am trying to stack several GRU cells in a network. However, I cannot figure out which tensor is the final state of the MultiRNN cell.

For example, when I use the following code:

import tensorflow as tf

num_units = [128, 128]                      # two stacked GRU layers, 128 units each
tf.reset_default_graph()
x = tf.placeholder(tf.int32, [None, 134])   # token ids, sequence length 134
y = tf.placeholder(tf.int32, [None])        # integer class labels
embedding_matrix = tf.Variable(tf.random_uniform([153, 128], -1.0, 1.0))
embeddings = tf.nn.embedding_lookup(embedding_matrix, x)
cells = [tf.contrib.rnn.GRUCell(num_units=n) for n in num_units]
cell_type = tf.contrib.rnn.MultiRNNCell(cells=cells, state_is_tuple=True)
cell_type = tf.contrib.rnn.DropoutWrapper(cell=cell_type, output_keep_prob=0.75)
_, (encoding, _) = tf.nn.dynamic_rnn(cell_type, embeddings, dtype=tf.float32)
the return value of tf.nn.dynamic_rnn in the last line is:

(<tf.Tensor 'rnn/transpose_1:0' shape=(?, 134, 128) dtype=float32>, (<tf.Tensor 'rnn/while/Exit_3:0' shape=(?, 128) dtype=float32>, <tf.Tensor 'rnn/while/Exit_4:0' shape=(?, 128) dtype=float32>))
With four stacked cells the output is even more confusing. Its format is:

(a, [b, c, d, e])


I cannot tell which of the above is the final memory state that I can process further for loss calculation and prediction.

OK, I found the answer: it is the state of the last cell in the MultiRNN stack. The following code extracts the corresponding state based on the length of the num_units list that was passed in:

num_units = [128, 128, 128, 128]
rnn_output, final_states = tf.nn.dynamic_rnn(cell_type, embeddings, dtype=tf.float32)
# final_states contains one entry per stacked cell; the last entry is the top layer's final state
encoding = final_states[len(num_units)-1]     # for GRU & plain RNN cells
encoding = final_states[len(num_units)-1][0]  # for LSTM cells (each entry is an LSTMStateTuple)
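
For completeness, here is a minimal sketch (not from the original post) of how the extracted encoding could be used for loss calculation and prediction, assuming TF 1.x and the placeholders x and y defined above; num_classes, logits, loss, predictions and train_op are illustrative names:

num_classes = 10  # hypothetical number of target classes
logits = tf.layers.dense(encoding, num_classes)  # project the final state to class scores
loss = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=logits))
predictions = tf.argmax(logits, axis=1, output_type=tf.int32)
train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)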