Warning: file_get_contents(/data/phpspider/zhask/data//catemap/2/tensorflow/5.json): failed to open stream: No such file or directory in /data/phpspider/zhask/libs/function.php on line 167

Warning: Invalid argument supplied for foreach() in /data/phpspider/zhask/libs/tag.function.php on line 1116

Notice: Undefined index: in /data/phpspider/zhask/libs/function.php on line 180

Warning: array_chunk() expects parameter 1 to be array, null given in /data/phpspider/zhask/libs/function.php on line 181
Python Tensorflow:如何从rnn_cell.BasicLSTM&;获取所有变量;多单元_Python_Tensorflow - Fatal编程技术网

Python Tensorflow:如何从rnn_cell.BasicLSTM&;获取所有变量;多单元

Python Tensorflow:如何从rnn_cell.BasicLSTM&;获取所有变量;多单元,python,tensorflow,Python,Tensorflow,我有一个设置,需要在使用tf.initialize\u all\u variables()的主初始化之后初始化LSTM。例如,我想调用tf.初始化变量([var\u list]) 是否有方法收集以下两方面的所有内部可培训变量: rnn_cell.BasicLSTM 多单元 这样我就可以初始化这些参数了 我想这样做的主要原因是我不想重新初始化以前的一些训练值。解决问题的最简单方法是使用变量范围。作用域内变量的名称将以其名称作为前缀。以下是一个简短的片段: cell = rnn_cell.Bas

我有一个设置,需要在使用
tf.initialize\u all\u variables()
的主初始化之后初始化LSTM。例如,我想调用
tf.初始化变量([var\u list])

是否有方法收集以下两方面的所有内部可培训变量:

  • rnn_cell.BasicLSTM
  • 多单元
这样我就可以初始化这些参数了


我想这样做的主要原因是我不想重新初始化以前的一些训练值。

解决问题的最简单方法是使用变量范围。作用域内变量的名称将以其名称作为前缀。以下是一个简短的片段:

cell = rnn_cell.BasicLSTMCell(num_nodes)

with tf.variable_scope("LSTM") as vs:
  # Execute the LSTM cell here in any way, for example:
  for i in range(num_steps):
    output[i], state = cell(input_data[i], state)

  # Retrieve just the LSTM variables.
  lstm_variables = [v for v in tf.all_variables()
                    if v.name.startswith(vs.name)]

# [..]
# Initialize the LSTM variables.
tf.initialize_variables(lstm_variables)
它与
MultiRNNCell
的工作方式相同

编辑:将
tf.可训练变量
更改为
tf.所有变量()

(部分抄袭拉法的答案)

注意,最后一行相当于Rafal代码中的列表理解

基本上,tensorflow存储变量的全局集合,可以通过
tf.all_variables()
tf.get_集合(tf.GraphKeys.variables)
获取。如果在函数中指定
作用域
(作用域名称),则仅获取作用域在指定作用域下的集合中的张量(本例中为变量)

编辑:
您还可以使用
tf.GraphKeys.TRAINABLE_VARIABLES
仅获取可训练变量。但由于vanilla BasicLSTMCell不会初始化任何不可训练的变量,因此两者在功能上是等效的。有关默认图形集合的完整列表,请查看。

这太完美了,谢谢。我没有意识到
tf.trainable_variables()
尊重范围,但事后看来这是有道理的!希望添加
tf.all\u variables()
而不是
tf.trainable\u variables()
将是更好的选择。主要是因为优化器没有可训练的变量,但是仍然需要初始化。您可能应该检查
v.name.startswith(vs.name+“/”)
,因为可能存在另一个具有相同名称前缀的作用域,例如“LSTM2”。它可以工作,但有时我希望LSTM的设置与TensorFlow中所有其他类型的变量保持一致。这是比Rafal的解决方案更好的方法:-)正如我上面所评论的,您可能最好使用
tf.get_collection(…,scope=vs.name+“/”)
,因为可能还有另一个作用域名为“LSTM2”左右。
cell = rnn_cell.BasicLSTMCell(num_nodes)
with tf.variable_scope("LSTM") as vs:
  # Execute the LSTM cell here in any way, for example:
  for i in range(num_steps):
    output[i], state = cell(input_data[i], state)

  lstm_variables = tf.get_collection(tf.GraphKeys.VARIABLES, scope=vs.name)