Python 如何在tensorflow2中恢复特定的检查点(以实现早期停止)?

Python 如何在tensorflow2中恢复特定的检查点(以实现早期停止)?,python,tensorflow,tensorflow2.0,checkpointing,Python,Tensorflow,Tensorflow2.0,Checkpointing,我使用以下代码在我训练模型的循环之外创建了一个检查点管理器: checkpoint_path = "./checkpoints/train" ckpt = tf.train.Checkpoint(object_1=object_1) ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=1) 然后,在训练模型时,我使用ckpt\u save\u path=ckpt\u

我使用以下代码在我训练模型的循环之外创建了一个检查点管理器:

checkpoint_path = "./checkpoints/train"

ckpt = tf.train.Checkpoint(object_1=object_1)

ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=1)

然后,在训练模型时,我使用
ckpt\u save\u path=ckpt\u manager.save()
在每个历元之后保存变量

考虑到我想要实现一种提前停止的方法,我需要在一个特定的历元之后恢复所有变量,并使用这些变量进行预测。如果我使用了上面的代码保存变量(希望保存过程正确),那么在e纪元之后如何恢复变量。我知道我可以先创建相同的变量和对象,然后使用下面的代码来恢复最新的检查点,但不知道如何恢复特定的检查点(如纪元编号e之后的变量),而不是最新的检查点

ckpt.restore(ckpt\u manager.latest\u checkpoint).assert\u consumered()


谢谢,

是的,您需要生成带有历元编号的文件名文本字符串

c_manager = tf.train.CheckpointManager(checkpoint, ...)

if EPOCH == '':
    if c_manager.latest_checkpoint:
        tf.print("-----------Restoring from {}-----------".format(
            c_manager.latest_checkpoint))
        checkpoint.restore(c_manager.latest_checkpoint)
        EPOCH = c_manager.latest_checkpoint.split(sep='ckpt-')[-1]
    else:
        tf.print("-----------Initializing from scratch-----------")
else:    
    checkpoint_fname = CHECKPOINT_SAVE_DIR + 'ckpt-' + str(EPOCH)
    tf.print("-----------Restoring from {}-----------".format(checkpoint_fname))
    checkpoint.restore(checkpoint_fname)

谢谢我认为另一件值得补充的事情是,在创建ckpt\U管理器以在训练期间保存模型时,我应该使用
ckpt\u管理器=tf.train.CheckpointManager(ckpt,checkpoint\u path,max\u to\u keep=EPOCHS)
而不是``
ckpt\u管理器=tf.train.CheckpointManager(ckpt,checkpoint\u path,max\u to\u keep=1)
。注意,EPOCHS是训练中的历元总数,历元是验证损失最小的历元。