tf.keras如何保存ModelCheckPoint对象

tf.keras如何保存ModelCheckPoint对象,keras,callback,pickle,google-colaboratory,tf.keras,Keras,Callback,Pickle,Google Colaboratory,Tf.keras,ModelCheckpoint可用于根据特定的监控指标保存最佳模型。因此,它显然有关于存储在其对象中的最佳度量的信息。例如,如果您在GoogleColab上训练,您的实例可能会在没有警告的情况下被杀死,并且在长时间的训练后,您将丢失这些信息 我试图pickle ModelCheckpoint对象,但得到: TypeError: can't pickle _thread.lock objects 这样,当我把笔记本拿回来时,我可以重复使用这个相同的对象。有什么好办法吗?您可以尝试通过以下方式

ModelCheckpoint可用于根据特定的监控指标保存最佳模型。因此,它显然有关于存储在其对象中的最佳度量的信息。例如,如果您在GoogleColab上训练,您的实例可能会在没有警告的情况下被杀死,并且在长时间的训练后,您将丢失这些信息

我试图pickle ModelCheckpoint对象,但得到:

TypeError: can't pickle _thread.lock objects  
这样,当我把笔记本拿回来时,我可以重复使用这个相同的对象。有什么好办法吗?您可以尝试通过以下方式复制:

chkpt_cb = tf.keras.callbacks.ModelCheckpoint('model.{epoch:02d}-{val_loss:.4f}.h5',
                                              monitor='val_loss',
                                              verbose=1,
                                              save_best_only=True)

with open('chkpt_cb.pickle', 'w') as f:
  pickle.dump(chkpt_cb, f, protocol=pickle.HIGHEST_PROTOCOL)

我认为您可能误解了
ModelCheckpoint
对象的预期用途。这是一种在特定阶段的训练中定期调用的方法。特别是ModelCheckpoint回调在每个历元之后都会被调用(如果您保持默认的
period=1
),并将您的模型保存到磁盘中,保存在您指定给
filepath
参数的文件名中。模型的保存方式与描述的相同。然后,如果您想稍后加载该模型,可以执行以下操作

from keras.models import load_model
model = load_model('my_model.h5')
其他答案为从保存的模型继续培训提供了很好的指导和示例,例如:。重要的是,保存的H5文件存储了继续培训所需的有关模型的所有信息

正如中所建议的,您不应该使用pickle来序列化您的模型。只需使用“fit”函数注册ModelCheckpoint回调:

chkpt_cb = tf.keras.callbacks.ModelCheckpoint('model.{epoch:02d}-{val_loss:.4f}.h5',
                                              monitor='val_loss',
                                              verbose=1,
                                              save_best_only=True)
model.fit(x_train, y_train,
          epochs=100,
          steps_per_epoch=5000,
          callbacks=[chkpt_cb])

您的模型将保存在一个名为的H5文件中,并自动为您格式化历元编号和损耗值。例如,您为第五个历元保存的文件(丢失0.0023)看起来像
model.05-.0023.h5
,并且由于您设置了
save\u best\u only=True
,只有当您的丢失情况比以前保存的好时,模型才会被保存,这样您就不会用一堆不需要的模型文件污染目录。

如果回调对象不被pickle(由于线程问题,不可取),我可以改为pickle:

best = chkpt_cb.best
这将存储回调所看到的最佳监控指标,它是一个浮点数,您可以在下次对其进行pickle并重新加载,然后执行以下操作:

chkpt_cb.best = best   # if chkpt_cb is a brand new object you create when colab killed your session. 
这是我自己的设置:

# All paths should be on Google Drive, I omitted it here for simplicity.

chkpt_cb = tf.keras.callbacks.ModelCheckpoint(filepath='model.{epoch:02d}-{val_loss:.4f}.h5',
                                              monitor='val_loss',
                                              verbose=1,
                                              save_best_only=True)

if os.path.exists('chkpt_cb.best.pickle'):
  with open('chkpt_cb.best.pickle', 'rb') as f:
    best = pickle.load(f)
    chkpt_cb.best = best

def save_chkpt_cb():
  with open('chkpt_cb.best.pickle', 'wb') as f:
    pickle.dump(chkpt_cb.best, f, protocol=pickle.HIGHEST_PROTOCOL)

save_chkpt_cb_callback = tf.keras.callbacks.LambdaCallback(
    on_epoch_end=lambda epoch, logs: save_chkpt_cb()
)

history = model.fit_generator(generator=train_data_gen,
                          validation_data=dev_data_gen,
                          epochs=5,
                          callbacks=[chkpt_cb, save_chkpt_cb_callback])

因此,即使您的colab会话被终止,您仍然可以检索最后一个最佳指标,并将其告知您的新实例,并像往常一样继续培训。当您重新编译一个有状态的优化器时,这尤其有用,可能会导致损失/度量的回归,并且不想在最初的几个时期保存这些模型

你能发布你正在使用的当前代码块吗?ModelCheckpoint通常是一个回调,因此从您的描述中不清楚您是如何使用它的。@adamconkey我已经用代码更新了它以进行复制。这相当简单。我只想对回调对象进行pickle处理。根据错误,它一定与线程问题有关。我找到的快速ans:Pickle chkpt_cb.best,然后将其重新分配到新的检查点。我试过了,它很管用。是的,我知道这就是它应该被使用的方式。如果你使用了COLAB并在训练中被截断,你会发现,如果从头开始重新调用回调函数,那么最后一个最佳度量将被遗忘。因此,我试图找到回调对象可以持久保存在磁盘上的解决方案。如果您的笔记本会话处于活动状态,它肯定会在内存中运行。你可以运行多个fit(…),它仍然跟踪到目前为止最好的指标。最好的度量肯定存储在回调对象中。