gpt4 book ai didi

python - Scikit-learn 管道中的 Keras 模型与早期停止

转载 作者:行者123 更新时间:2023-12-04 13:30:31 24 4
gpt4 key购买 nike

这个问题在这里已经有了答案:





Can I send callbacks to a KerasClassifier?

(4 个回答)


11 个月前关闭。




我正在训练一个 Keras 模型,该模型位于带有一些预处理的 Scikit 管道中。 Keras 模型定义为

from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Input, Dropout
from tensorflow.keras import optimizers
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.wrappers.scikit_learn import KerasRegressor
from sklearn.pipeline import make_pipeline


def create_model(X_train):
inp = Input(shape=(X_train.shape[1],))
x = Dense(150, activation="relu")(inp)
x = Dropout(0.4)(x)
mean = Dense(1, activation="linear")(x)
train_model_1 = Model(inp, mean)
adam = optimizers.Adam(lr=0.01)
train_model_1.compile(loss=my_loss_function, optimizer=adam)
return train_model_1


clf = KerasRegressor(build_fn=create_model, epochs=250, batch_size=64)
然后在 Pipeline 中使用和
pipeline = make_pipeline(
other_steps,
clf(X_train)
)


pipeline.fit(X_train, y_train)
我想用 EarlyStopping其中测试数据 ( X_test, y_test ) 用于验证。这通常很简单
callbacks=[EarlyStopping(monitor='val_loss', patience=5)]

train_model_1.fit(X_train, y_train,
validation_data=(X_test, y_test),
callbacks=callbacks,
)
但我无法弄清楚这将在管道中的何处进行。构建它的正确方法是什么?

最佳答案

Pipeline.fit 有一个关键字参数参数:

**fit_params : dict of string -> object

Parameters passed to the fit method of each step, where each parameter name is prefixed such that parameter p for step s has key s__p.


所以像 pipeline.fit(x_train, y_train, kerasregressor__callbacks=callbacks)应该管用。 (检查管道步骤的名称,例如使用 pipeline.stepsmake_pipeline 使用类的小写名称生成名称,但我不确定这是否适用于 keras 。)
另见 How to pass a parameter to only one part of a pipeline object in scikit learn?

关于python - Scikit-learn 管道中的 Keras 模型与早期停止,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/65372223/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com