Warning: file_get_contents(/data/phpspider/zhask/data//catemap/2/tensorflow/5.json): failed to open stream: No such file or directory in /data/phpspider/zhask/libs/function.php on line 167

Warning: Invalid argument supplied for foreach() in /data/phpspider/zhask/libs/tag.function.php on line 1116

Notice: Undefined index: in /data/phpspider/zhask/libs/function.php on line 180

Warning: array_chunk() expects parameter 1 to be array, null given in /data/phpspider/zhask/libs/function.php on line 181
在tensorflow中,如何将一个图与另一个具有相同结构的图赋值?_Tensorflow_Deep Learning_Reinforcement Learning - Fatal编程技术网

在tensorflow中,如何将一个图与另一个具有相同结构的图赋值?

在tensorflow中,如何将一个图与另一个具有相同结构的图赋值?,tensorflow,deep-learning,reinforcement-learning,Tensorflow,Deep Learning,Reinforcement Learning,我试图在tensorflow中实现DQN。在这里,我有一个目标网络和一个培训网络,它们彼此具有相同的结构。在每10000个训练步骤开始时,我想将值从检查点加载到目标网络和训练网络,然后停止目标网络。然而,我尝试了这些方法,但没有一种有效: 1,将两个网络放在一个图中。但是,每次加载它们时,我都不知道如何将训练网络部分的值分配给目标网络部分。它们保存在不同的值中,因为其中一个是“停止梯度” 2、使用tf.graph定义两个图,分别运行两个会话。但是,我无法将一个图的检查点加载到另一个图,即使它们具

我试图在tensorflow中实现DQN。在这里,我有一个目标网络和一个培训网络,它们彼此具有相同的结构。在每10000个训练步骤开始时,我想将值从检查点加载到目标网络和训练网络,然后停止目标网络。然而,我尝试了这些方法,但没有一种有效:

1,将两个网络放在一个图中。但是,每次加载它们时,我都不知道如何将训练网络部分的值分配给目标网络部分。它们保存在不同的值中,因为其中一个是“停止梯度”

2、使用tf.graph定义两个图,分别运行两个会话。但是,我无法将一个图的检查点加载到另一个图,即使它们具有相同的结构。毕竟,它们是两个不同的图形


有谁能给我一些建议吗?非常感谢

典型的方法是将所有内容放在一个图中,将两个网络放在两个名称范围内,然后为一个范围内的每个变量创建tf.assign ops到另一个范围,并使用tf.group构造最终的复制操作。让我们假设函数create_net构建单个网络

with tf.name_scope('main_network'):
  main_net = create_net()

with tf.name_scope('target_network):
  target_network = create_net()

main_variables = tf.get_collection(tf.GraphKeys.VARIABLES, scope='main_network') 
target_variables = tf.get_collection(tf.GraphKeys.VARIABLES, scope='target_network')

# I am assuming get_collection returns variables in the same order, please double
# check this is actually happening

assign_ops = []
for main_var, target_var in zip(main_variables, target_variables):
  assign_ops.append(tf.assign(target_var, tf.identity(main_var)))

copy_operation = tf.group(*assign_ops)
现在在session.run中执行copy_操作,应该将主要网络参数复制到目标网络。上述代码应被视为伪代码,而不是可以复制粘贴的代码