gpt4 book ai didi

python - Keras 自定义回调 : Grid point training termination condition

转载 作者:太空宇宙 更新时间:2023-11-04 04:35:13 31 4
gpt4 key购买 nike

我正在使用 scikit-learnGridSearchCV 在我的 keras 神经网络上进行网格搜索。我想自定义 callback,这样每次在一个网格点上完成网络训练时,我都可以打印出匹配完成。

假设我定义我的网格如下:

param_grid = dict(epochs=[50, 100, 500, 1000],
learn_rate=[0.1, 0.2, 0.3],
momentum=[0.01, 0.1],
dropout_rate=[0.05, 0.1, 0.15, 0.2])

我计算网格上的可能性总数为:

grid_size = reduce(lambda x,y: x*y,[len(param_grid_[key]) for key in param_grid])

回调是:

from keras.callbacks import ModelCheckpoint, EarlyStopping
# checkpoint
filepath="best_model.hdf5"
checkpoint = ModelCheckpoint(filepath, monitor='val_acc', verbose=1,
save_best_only=True, mode='max')
# Early stoping
monitor = EarlyStopping(monitor='val_loss', min_delta=1e-5, patience=200,
verbose=1, mode='auto')

callbacks_list = [checkpoint, monitor, LiveGridReport()]

其中 LiveGridReport() 是我自定义的回调,它打印关于在网格点上完成训练的消息。

class LiveGridReport(keras.callbacks.Callback):

def __init__(self, grid_size):
grid_size_ = grid_size

def on_train_begin(self, logs={}):
return

def on_train_end(self, logs={}):
return

我的问题是,考虑到我还有 EarlyStopping 回调,我不知道如何检测网格点上的训练已终止。

最佳答案

确定在使用 EarlyStopping 回调时停止训练的时期可以使用 stopped_epoch 来完成

EarlyStopping.stopped_epoch

或使用历史记录

history = model.fit(....)
number_of_epochs_it_ran = len(history.history['loss'])

关于python - Keras 自定义回调 : Grid point training termination condition,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/51890070/

31 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com