Tensorflow KERA在EPHO结束前评估验证数据

Tensorflow KERA在EPHO结束前评估验证数据,tensorflow,keras,tf.keras,Tensorflow,Keras,Tf.keras,我想用凯拉斯训练我的模特。我使用的是一个巨大的数据集,其中一个训练阶段有超过30000个步骤。我的问题是,在检查验证数据集上的模型改进之前,我不想等待一段时间。有没有办法让Keras每1000步对培训数据评估一次验证数据?我认为一种选择是使用回调,但是Keras有内置的解决方案吗 if train: log('Start training') history = model.fit(train_dataset, steps_per_ep

我想用凯拉斯训练我的模特。我使用的是一个巨大的数据集,其中一个训练阶段有超过30000个步骤。我的问题是,在检查验证数据集上的模型改进之前,我不想等待一段时间。有没有办法让Keras每1000步对培训数据评估一次验证数据?我认为一种选择是使用回调,但是Keras有内置的解决方案吗

if train:
    log('Start training')
    history = model.fit(train_dataset,
                      steps_per_epoch=train_steps,
                      epochs=50,
                      validation_data=val_dataset,
                      validation_steps=val_steps,
                      callbacks=[
                            keras.callbacks.EarlyStopping(
                                monitor='loss',
                                patience=10,
                                restore_best_weights=True,
                            ),
                            keras.callbacks.ModelCheckpoint(
                                filepath=f'model.h5',
                                monitor='val_loss',
                                save_best_only=True,
                                save_weights_only=True,
                            ),
                            keras.callbacks.ReduceLROnPlateau(
                                monitor = "val_loss", 
                                factor = 0.5, 
                                patience = 3, 
                                min_lr=0.001,
                            ),
                        ],
                )

使用内置回调,您无法做到这一点。您需要的是实现一个自定义回调

class MyCustomCallback(tf.keras.callbacks.Callback):

  def on_train_batch_begin(self, batch, logs=None):
    print('Training: batch {} begins at {}'.format(batch, datetime.datetime.now().time()))

  def on_train_batch_end(self, batch, logs=None):
    print('Training: batch {} ends at {}'.format(batch, datetime.datetime.now().time()))

  def on_test_batch_begin(self, batch, logs=None):
    print('Evaluating: batch {} begins at {}'.format(batch, datetime.datetime.now().time()))

  def on_test_batch_end(self, batch, logs=None):
    print('Evaluating: batch {} ends at {}'.format(batch, datetime.datetime.now().time()))
这来自TensorFlow文档

您可以覆盖\u train\u batch\u end()函数上的
,并且由于批处理参数是整数,您可以验证
batch%100==0
,然后根据需要验证
self.model.predict(val\u data)


请在这里查看我的答案:对如何重写自定义回调函数有一个很好的概述。请注意,在您的情况下,重要的是火车上的
批处理
而不是火车上的
批处理

我想避免编写自定义回调,因为我不知道它这么简单