gpt4 book ai didi

python - 当验证损失满足某些标准时提前停止

转载 作者:太空宇宙 更新时间:2023-11-04 02:19:55 25 4
gpt4 key购买 nike

我正在 Keras 中训练神经网络模型。我想监控验证损失并在达到特定条件时停止训练。

我知道我可以使用 EarlyStopping 在给定的 patience 轮次训练没有改善时停止训练。

我想要一些不同的东西。我想在 n 轮后 val_loss 超过 x 值时停止训练。

为了清楚起见,假设 0.5 中的 xn50。我只想在 epoch 数大于 50val_loss 高于 0.5 时停止模型训练。

我如何在 Keras 中执行此操作?

最佳答案

您可以通过继承 Keras EarlyStopping 回调并使用您自己的逻辑覆盖它来定义您自己的回调:

from keras.callbacks import EarlyStopping # use as base class

class MyCallBack(EarlyStopping):
def __init__(self, threshold, min_epochs, **kwargs):
super(MyCallBack, self).__init__(**kwargs)
self.threshold = threshold # threshold for validation loss
self.min_epochs = min_epochs # min number of epochs to run

def on_epoch_end(self, epoch, logs=None):
current = logs.get(self.monitor)
if current is None:
warnings.warn(
'Early stopping conditioned on metric `%s` '
'which is not available. Available metrics are: %s' %
(self.monitor, ','.join(list(logs.keys()))), RuntimeWarning
)
return

# implement your own logic here
if (epoch >= self.min_epochs) & (current >= self.threshold):
self.stopped_epoch = epoch
self.model.stop_training = True

说明它应该工作的小例子:

from keras.layers import Input, Dense
from keras.models import Model
import numpy as np

# Generate some random data
features = np.random.rand(100, 5)
labels = np.random.rand(100, 1)

validation_feat = np.random.rand(100, 5)
validation_labels = np.random.rand(100, 1)

# Define a simple model
input_layer = Input((5, ))
dense_layer = Dense(10)(input_layer)
output_layer = Dense(1)(dense_layer)
model = Model(inputs=input_layer, outputs=output_layer)
model.compile(loss='mse', optimizer='sgd')

# Fit with custom callback
callbacks = [MyCallBack(threshold=0.001, min_epochs=10, verbose=1)]
model.fit(features, labels, validation_data=(validation_feat, validation_labels), callbacks=callbacks, epochs=100)

关于python - 当验证损失满足某些标准时提前停止,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/51874695/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com