在tensorflow中,当图形被修改时,如何使用;“监控培训会话”;是否仅恢复部分检查点?

在tensorflow中,当图形被修改时,如何使用;“监控培训会话”;是否仅恢复部分检查点?,tensorflow,Tensorflow,我的目的很简单也很清楚:在部分修改图形后,如何从以前日志的检查点文件恢复未更改的变量/参数?(最好使用MonitoredTrainingSession) 我从这里对代码进行测试: 在resnet_model.py的第116-118行中,原始代码(或图形)为: 在第一次培训之后,我获得了检查点文件。 然后我将代码修改为: with tf.variable_scope('logit_modified'): logits_modified = self._fully_connected('f

我的目的很简单也很清楚:在部分修改图形后,如何从以前日志的检查点文件恢复未更改的变量/参数?(最好使用MonitoredTrainingSession)

我从这里对代码进行测试:

在resnet_model.py的第116-118行中,原始代码(或图形)为:

在第一次培训之后,我获得了检查点文件。 然后我将代码修改为:

with tf.variable_scope('logit_modified'):
    logits_modified = self._fully_connected('fc_1',x, 48)
    #self.predictions = tf.nn.softmax(logits)    
with tf.variable_scope('logit_2'):
    logits_2 = self._fully_connected('fc_2', logits_modified, 
    self.hps.num_classes)
    self.predictions = tf.nn.softmax(logits_2)
with tf.variable_scope('costs'):
    xent = tf.nn.softmax_cross_entropy_with_logits(
    logits=logits_2, labels=self.labels)
    self.cost = tf.reduce_mean(xent, name='xent')
    self.cost += self._decay()
然后,我尝试使用最新的API tf.train.MonitoredTrainingSession来恢复在第一次训练中获得的检查点。我尝试过多种方法来做这件事,但没有一种有效

尝试1: 如果我在监控培训课程中不使用脚手架:

with tf.train.MonitoredTrainingSession(
    checkpoint_dir=FLAGS.log_root,
    #scaffold=scaffold,
    hooks=[logging_hook, _LearningRateSetterHook()],
    chief_only_hooks=[summary_hook],
    save_checkpoint_secs = 600,
    # Since we provide a SummarySaverHook, we need to disable default
    # SummarySaverHook. To do that we set save_summaries_steps to 0.
    save_summaries_steps=None,
    save_summaries_secs=None,
    config=tf.ConfigProto(allow_soft_placement=True),
    stop_grace_period_secs=120,
    log_step_count_steps=100) as mon_sess:
while not mon_sess.should_stop():
    mon_sess.run(_train_op)
错误消息如下:

2017-12-29 10:33:30.699061:W tensorflow/core/framework/op_kernel.cc:1192]未找到:在检查点中未找到密钥logit_modified/fc_1/偏差/动量

虽然会话似乎尝试根据修改后的图进行恢复,但不根据新图和上一个检查点文件中存在的变量进行恢复(换句话说,所有层都不包括最后2个)

尝试2: 受使用tf.train.Supervisor的转移学习代码的启发: ,来自第251行

首先,我修改了resnet_model.py中的代码,添加以下行:

self.variables_to_restore = tf.contrib.framework.get_variables_to_restore(
exclude=["logit_modified", "logit_2"])
然后,MonitoredTrainingSession中的脚手架将更改为:

saver = tf.train.Saver(variables_to_restore)
def restore_fn(sess):
    return saver.restore(sess, FLAGS.log_root)
scaffold = tf.train.Scaffold(saver=saver, init_fn = restore_fn)
不幸的是,显示了以下错误消息:

RuntimeError:初始化操作未使模型为本地初始化做好准备。Init op:group_deps,Init fn:at 0x7f0ec26f4320>,错误:变量未初始化:logit_modified/fc_1/DW

似乎最后两层未正确恢复,因此其余层未恢复

尝试3: 我也尝试了这里列出的方法,但没有一种有效

我知道还有其他方法可以恢复,比如中的代码,但它们是嵌套的,不够通用,无法轻松应用于其他模型。这就是我想使用“MonitoredTrainingSession”的原因


那么,如何使用“MonitoredTrainingSession”只恢复tensorflow中的部分检查点呢

好吧,我终于明白了

在此处读取监视的_session.py后: ,我发现关键点(也是非常棘手的)是更改为一个新的空检查点目录,这样MonitoredTrainingSession就不会忽略init_op或init_fn。 然后,您可以使用以下代码构建init_fn(以恢复检查点)以及scaffold:

variables_to_restore = tf.contrib.framework.get_variables_to_restore(
    exclude=['XXX'])    
init_assign_op, init_feed_dict = tf.contrib.framework.assign_from_checkpoint(
    ckpt.model_checkpoint_path, variables_to_restore)
def InitAssignFn(scaffold,sess):
    sess.run(init_assign_op, init_feed_dict)

scaffold = tf.train.Scaffold(saver=tf.train.Saver(), init_fn=InitAssignFn)
记住ckpt。上面的model\u checkpoint\u路径是您的旧检查点路径,其中包含预训练的文件。我上面提到的新的空检查点目录在这里表示MonitoredTrainingSession的参数“checkpoint_dir”:

with tf.train.MonitoredTrainingSession(
    checkpoint_dir=FLAGS.log_root_2,...) as mon_sess:
while not mon_sess.should_stop():
    mon_sess.run(_train_op)
我修改的代码的第一段来自tf.slim中的learning.py,来自第134行:

加上: 感谢本问答的灵感,尽管解决方案有点不同:

很抱歉,错误消息总是在StackOverflow上出现缩进问题。问题中说明了错误消息的主要思想。我发现在代码中使用新的保护程序加载预训练模型很重要。当我使用相同的保护程序加载预训练模型和保存检查点时,该保护程序将只保存加载的检查点中出现的操作,并忽略新添加的操作,即使没有指定变量\u to \u restore。
with tf.train.MonitoredTrainingSession(
    checkpoint_dir=FLAGS.log_root_2,...) as mon_sess:
while not mon_sess.should_stop():
    mon_sess.run(_train_op)