Python 验证丢失满足特定标准时提前停止
我正在Keras中训练一个神经网络模型。我希望监控验证丢失,并在达到特定条件时停止培训 我知道,在给定的Python 验证丢失满足特定标准时提前停止,python,machine-learning,keras,deep-learning,Python,Machine Learning,Keras,Deep Learning,我正在Keras中训练一个神经网络模型。我希望监控验证丢失,并在达到特定条件时停止培训 我知道,在给定的耐心轮数下,如果训练没有改善,我可以使用提前停止停止训练 我想做些不同的事情。当val_loss在n轮后高于某个值时,我想停止训练,比如x 为了清楚起见,我们假设0.5中的x,n是50。仅当epoch数字大于50且valu loss大于0.5时,我才想停止模型的训练 如何在Keras中执行此操作?您可以通过继承Keras早期停止回调并使用自己的逻辑覆盖它来定义自己的回调: from keras
耐心
轮数下,如果训练没有改善,我可以使用提前停止
停止训练
我想做些不同的事情。当val_loss
在n
轮后高于某个值时,我想停止训练,比如x
为了清楚起见,我们假设0.5中的x
,n
是50
。仅当epoch
数字大于50
且valu loss
大于0.5
时,我才想停止模型的训练
如何在Keras中执行此操作?您可以通过继承Keras早期停止
回调并使用自己的逻辑覆盖它来定义自己的回调:
from keras.callbacks import EarlyStopping # use as base class
class MyCallBack(EarlyStopping):
def __init__(self, threshold, min_epochs, **kwargs):
super(MyCallBack, self).__init__(**kwargs)
self.threshold = threshold # threshold for validation loss
self.min_epochs = min_epochs # min number of epochs to run
def on_epoch_end(self, epoch, logs=None):
current = logs.get(self.monitor)
if current is None:
warnings.warn(
'Early stopping conditioned on metric `%s` '
'which is not available. Available metrics are: %s' %
(self.monitor, ','.join(list(logs.keys()))), RuntimeWarning
)
return
# implement your own logic here
if (epoch >= self.min_epochs) & (current >= self.threshold):
self.stopped_epoch = epoch
self.model.stop_training = True
举例说明它应该起作用:
from keras.layers import Input, Dense
from keras.models import Model
import numpy as np
# Generate some random data
features = np.random.rand(100, 5)
labels = np.random.rand(100, 1)
validation_feat = np.random.rand(100, 5)
validation_labels = np.random.rand(100, 1)
# Define a simple model
input_layer = Input((5, ))
dense_layer = Dense(10)(input_layer)
output_layer = Dense(1)(dense_layer)
model = Model(inputs=input_layer, outputs=output_layer)
model.compile(loss='mse', optimizer='sgd')
# Fit with custom callback
callbacks = [MyCallBack(threshold=0.001, min_epochs=10, verbose=1)]
model.fit(features, labels, validation_data=(validation_feat, validation_labels), callbacks=callbacks, epochs=100)
不错。我试着自己实现这个。我犯了一些错误。我会参考这个