Tensorflow restore`tf.Session`使用`tf.train.MonitoredTrainingSession`保存检查点`
我有使用Tensorflow restore`tf.Session`使用`tf.train.MonitoredTrainingSession`保存检查点`,session,tensorflow,checkpoint,Session,Tensorflow,Checkpoint,我有使用tf.train.MonitoredTrainingSession训练CNN的代码 当我创建一个新的tf.train.MonitoredTrainingSession时,我可以将checkpoint目录作为输入参数传递给会话,它将自动恢复它可以找到的最新保存的checkpoint。我可以设置挂钩进行训练,直到某个步骤。例如,如果检查点的步骤是150000,我想训练到200000,我会将最后一步设置为200000 只要使用tf.train.MonitoredTrainingSession保
tf.train.MonitoredTrainingSession
训练CNN的代码
当我创建一个新的tf.train.MonitoredTrainingSession
时,我可以将checkpoint
目录作为输入参数传递给会话,它将自动恢复它可以找到的最新保存的checkpoint
。我可以设置挂钩
进行训练,直到某个步骤。例如,如果检查点的步骤是150000
,我想训练到200000
,我会将最后一步设置为200000
只要使用tf.train.MonitoredTrainingSession
保存了最新的检查点,上述过程就可以正常工作。然而,如果我试图恢复一个使用正常的tf.Session
保存的检查点,那么所有的麻烦都会消失。它在图表中找不到一些键
培训是通过以下方式完成的:
with tf.train.MonitoredTrainingSession(
checkpoint_dir=FLAGS.retrain_dir,
hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_training_steps),
tf.train.NanTensorHook(loss),
_LoggerHook()],
config=tf.ConfigProto(
log_device_placement=FLAGS.log_device_placement)) as mon_sess:
while not mon_sess.should_stop():
mon_sess.run(train_op)
如果checkpoint\u dir
属性有一个没有检查点的文件夹,则将从头开始。如果它有上一次培训课程中保存的检查点
,它将恢复最新的检查点
,并继续培训
现在,我正在恢复最新的检查点
,并修改一些变量并保存它们:
saver = tf.train.Saver(variables_to_restore)
ckpt = tf.train.get_checkpoint_state(FLAGS.train_dir)
with tf.Session() as sess:
if ckpt and ckpt.model_checkpoint_path:
# Restores from checkpoint
saver.restore(sess, ckpt.model_checkpoint_path)
print(ckpt.model_checkpoint_path)
restored_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
FLAGS.max_training_steps = int(restored_step) + FLAGS.max_training_steps
else:
print('No checkpoint file found')
return
prune_convs(sess)
saver.save(sess, FLAGS.retrain_dir+"model.ckpt-"+restored_step)
正如您所见,就在saver.save…
之前,我正在修剪网络中的所有卷积层。不需要描述如何以及为什么这样做。关键是网络实际上已经被修改了。然后我将网络保存到检查点
现在,如果我在保存的修改后的网络上部署测试,测试工作正常。但是,当我尝试在保存的检查点上运行tf.train.MonitoredTrainingSession
时,它会显示:
检查点中未找到conv1键/体重减轻/平均值
此外,我还注意到,使用tf.Session
保存的检查点
的大小是使用tf.train.MonitoredTrainingSession
保存的检查点
的一半
我知道我做错了,有什么建议可以让这一切顺利进行吗?我明白了。显然,tf.Saver
不会从检查点还原所有变量。我尝试立即恢复和保存,输出的大小只有原来的一半
我使用tf.train.list_变量
从最新的检查点
获取所有变量,然后将它们转换为tf.Variable
,并从中创建一个dict
。然后我将dict
传递给tf.Saver
,它恢复了我所有的变量
下一步是初始化所有变量,然后修改权重
现在它正在工作