Tensorflow restore`tf.Session`使用`tf.train.MonitoredTrainingSession`保存检查点`

Tensorflow restore`tf.Session`使用`tf.train.MonitoredTrainingSession`保存检查点`,session,tensorflow,checkpoint,Session,Tensorflow,Checkpoint,我有使用tf.train.MonitoredTrainingSession训练CNN的代码 当我创建一个新的tf.train.MonitoredTrainingSession时,我可以将checkpoint目录作为输入参数传递给会话,它将自动恢复它可以找到的最新保存的checkpoint。我可以设置挂钩进行训练,直到某个步骤。例如,如果检查点的步骤是150000,我想训练到200000,我会将最后一步设置为200000 只要使用tf.train.MonitoredTrainingSession保

我有使用
tf.train.MonitoredTrainingSession
训练CNN的代码

当我创建一个新的
tf.train.MonitoredTrainingSession
时,我可以将
checkpoint
目录作为输入参数传递给会话,它将自动恢复它可以找到的最新保存的
checkpoint
。我可以设置
挂钩
进行训练,直到某个步骤。例如,如果
检查点的步骤是
150000
,我想训练到
200000
,我会将
最后一步设置为
200000

只要使用
tf.train.MonitoredTrainingSession
保存了最新的
检查点,上述过程就可以正常工作。然而,如果我试图恢复一个使用正常的
tf.Session
保存的
检查点,那么所有的麻烦都会消失。它在图表中找不到一些键

培训是通过以下方式完成的:

with tf.train.MonitoredTrainingSession(
    checkpoint_dir=FLAGS.retrain_dir,
    hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_training_steps),
           tf.train.NanTensorHook(loss),
           _LoggerHook()],
    config=tf.ConfigProto(
        log_device_placement=FLAGS.log_device_placement)) as mon_sess:
  while not mon_sess.should_stop():
    mon_sess.run(train_op)
如果
checkpoint\u dir
属性有一个没有检查点的文件夹,则将从头开始。如果它有上一次培训课程中保存的
检查点
,它将恢复最新的
检查点
,并继续培训

现在,我正在恢复最新的
检查点
,并修改一些变量并保存它们:

saver = tf.train.Saver(variables_to_restore)

ckpt = tf.train.get_checkpoint_state(FLAGS.train_dir)

with tf.Session() as sess:
  if ckpt and ckpt.model_checkpoint_path:
    # Restores from checkpoint
    saver.restore(sess, ckpt.model_checkpoint_path)
    print(ckpt.model_checkpoint_path)
    restored_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
    FLAGS.max_training_steps = int(restored_step) + FLAGS.max_training_steps
  else:
    print('No checkpoint file found')
    return

  prune_convs(sess)
  saver.save(sess, FLAGS.retrain_dir+"model.ckpt-"+restored_step)
正如您所见,就在
saver.save…
之前,我正在修剪网络中的所有卷积层。不需要描述如何以及为什么这样做。关键是网络实际上已经被修改了。然后我将网络保存到
检查点

现在,如果我在保存的修改后的网络上部署测试,测试工作正常。但是,当我尝试在保存的
检查点上运行
tf.train.MonitoredTrainingSession
时,它会显示:

检查点中未找到conv1键/体重减轻/平均值

此外,我还注意到,使用
tf.Session
保存的
检查点
的大小是使用
tf.train.MonitoredTrainingSession
保存的
检查点
的一半


我知道我做错了,有什么建议可以让这一切顺利进行吗?

我明白了。显然,
tf.Saver
不会从
检查点还原所有变量。我尝试立即恢复和保存,输出的大小只有原来的一半

我使用
tf.train.list_变量
从最新的
检查点
获取所有变量,然后将它们转换为
tf.Variable
,并从中创建一个
dict
。然后我将
dict
传递给
tf.Saver
,它恢复了我所有的变量

下一步是初始化所有变量,然后修改权重

现在它正在工作