TensorFlow tf.keras ModelCheckpoint: including the batch number in filepath
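I use ModelCheckpoint to save a checkpoint every 500 batches within each epoch. This is covered in the documentation.
How do I set filepath so that it includes the batch number? I know I can use {epoch} and the parameters in logs.

This may help, although the question is not entirely clear. The Callback class provides a number of methods that should cover your needs.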
Example code:
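from tensorflow.keras.callbacks import Callback

class WeightsSaver(Callback):
    """Save the model weights every N epochs."""

    def __init__(self, N):
        super().__init__()
        self.N = N
        self.epoch = 0  # number of epochs completed so far

    def on_epoch_end(self, epoch, logs=None):
        # Save on every N-th epoch, starting with the first one.
        if self.epoch % self.N == 0:
            name = 'weights%04d.hdf5' % self.epoch
            self.model.save_weights(name)
        self.epoch += 1

callbacks_list = [WeightsSaver(10)]  # save the weights every 10 epochs
model.fit(train_X, train_Y, epochs=n_epochs, callbacks=callbacks_list)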
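The callback above names files by epoch only. If the same naming scheme should also carry a batch number, as the question asks, the name could be built by a small helper and called from a batch-level hook. A minimal sketch, assuming a hypothetical helper `checkpoint_name` (not part of Keras) and zero-padded fields so that the files sort in training order:

```python
def checkpoint_name(epoch, batch, prefix="weights"):
    """Build a file name that encodes both the epoch and the batch number.

    `prefix` and the padding widths are illustrative choices, not Keras defaults.
    """
    return "%s_epoch%04d_batch%06d.hdf5" % (prefix, epoch, batch)

# Example: a checkpoint written in epoch 1 after batch 500.
name = checkpoint_name(1, 500)
print(name)  # weights_epoch0001_batch000500.hdf5
```

Padding matters mostly when checkpoints are listed or globbed later: without it, "batch1000" sorts before "batch500".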
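Assuming you use tf.keras.callbacks.ModelCheckpoint with save_freq set to an integer (that is, you want to save after a fixed number of batches), you can create a class that inherits from ModelCheckpoint and override its on_train_batch_end method: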
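import tensorflow as tf

class CustomCallback(tf.keras.callbacks.ModelCheckpoint):
    def __init__(self, filepath, save_freq):
        self.model_name = filepath
        self.save_freq = save_freq
        super().__init__(self.model_name, save_freq=self.save_freq)

    def on_train_batch_end(self, batch, logs=None):
        # _should_save_on_batch and _current_epoch are private members of
        # ModelCheckpoint, so this may break between TensorFlow versions.
        if self._should_save_on_batch(batch):
            # _current_epoch is used as-is (zero-based) so the names match the
            # sample output; batch is zero-based too, hence batch + 1.
            filename = self.model_name + "epoch_" + str(self._current_epoch) + "_batch_" + str(batch + 1) + '.tf'
            self.model.save_weights(filename)
            print("\nsaved checkpoint: " + filename + "\n")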
Then pass an instance of this class to model.fit:
SAVE_FREQ = 200 # number of batches
custom_callback = CustomCallback(filepath="checkpoint_", save_freq=SAVE_FREQ)
model.fit(..., callbacks=[custom_callback])
This adds the epoch index (zero-based in the output shown here) and the batch number to the checkpoint file name:
Epoch 1/3
199/422 [=============>................] - ETA: 6s - loss: 0.0261 - accuracy: 0.9915
saved checkpoint: checkpoint_epoch_0_batch_200.tf
399/422 [===========================>..] - ETA: 0s - loss: 0.0263 - accuracy: 0.9914
saved checkpoint: checkpoint_epoch_0_batch_400.tf
422/422 [==============================] - 13s 31ms/step - loss: 0.0264 - accuracy: 0.9914 - val_loss: 0.0311 - val_accuracy: 0.9920
Epoch 2/3
177/422 [===========>..................] - ETA: 7s - loss: 0.0254 - accuracy: 0.9913
saved checkpoint: checkpoint_epoch_1_batch_178.tf
377/422 [=========================>....] - ETA: 1s - loss: 0.0252 - accuracy: 0.9912
saved checkpoint: checkpoint_epoch_1_batch_378.tf
422/422 [==============================] - 13s 32ms/step - loss: 0.0252 - accuracy: 0.9912 - val_loss: 0.0306 - val_accuracy: 0.9925
Epoch 3/3
156/422 [==========>...................] - ETA: 7s - loss: 0.0253 - accuracy: 0.9914
saved checkpoint: checkpoint_epoch_2_batch_156.tf
355/422 [========================>.....] - ETA: 2s - loss: 0.0246 - accuracy: 0.9919
saved checkpoint: checkpoint_epoch_2_batch_356.tf
422/422 [==============================] - 13s 31ms/step - loss: 0.0245 - accuracy: 0.9919 - val_loss: 0.0294 - val_accuracy: 0.9922
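The batch numbers in the file names drift from one epoch to the next (200 and 400, then 178 and 378, then 156 and 356). This is because 422 is not a multiple of 200 and the save counter appears to keep running across epoch boundaries. A plain-Python sketch of that counting, assuming a counter that is never reset at the end of an epoch; `simulate_saves` is a hypothetical helper for illustration, not part of Keras:

```python
def simulate_saves(batches_per_epoch, save_freq, epochs):
    """Return (epoch, batch) pairs as they would appear in the file names.

    Epochs are zero-based and batches are one-based, as in the log above.
    """
    saves = []
    seen_since_last_save = 0  # assumption: carried over between epochs
    for epoch in range(epochs):
        for batch in range(batches_per_epoch):
            seen_since_last_save += 1
            if seen_since_last_save >= save_freq:
                seen_since_last_save = 0
                saves.append((epoch, batch + 1))
    return saves

saves = simulate_saves(batches_per_epoch=422, save_freq=200, epochs=3)
for epoch, batch in saves:
    print("checkpoint_epoch_%d_batch_%d.tf" % (epoch, batch))
```

With these numbers the sketch yields the same six epoch and batch pairs as the file names in the log. If checkpoints at the same batch positions in every epoch are wanted instead, the counter would have to be reset at the start of each epoch.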