gpt4 book ai didi

python - Keras 中的检查点深度学习模型

转载 作者:太空宇宙 更新时间:2023-11-03 15:43:45 25 4
gpt4 key购买 nike

我需要帮助在 Keras 中实现检查点功能。我要训练一个大型数据集,因此为了做到这一点,我首先使用鸢尾花数据集训练了一个模型:http://machinelearningmastery.com/multi-class-classification-tutorial-keras-deep-learning-library/

因为我自己的数据集与它非常相似,唯一的区别是我的数据集更大。

对于检查点功能:http://machinelearningmastery.com/check-point-deep-learning-models-keras/

我理解使用 pima-indians 数据集的示例。现在我尝试在 iris-flower 脚本中实现相同的检查点功能。这是我迄今为止尝试过的。

import numpy
from pandas import *
from keras.models import Sequential
from keras.layers import Dense
from keras.wrappers.scikit_learn import KerasClassifier
from keras.utils import np_utils
from sklearn.model_selection import cross_val_score, KFold
from sklearn.preprocessing import LabelEncoder
from sklearn.pipeline import Pipeline
from keras.callbacks import ModelCheckpoint

seed = 7
numpy.random.seed(seed)

dataframe = read_csv("iris.csv", header=None)
dataset = dataframe.values
X = dataset[:,0:4].astype(float)
Y = dataset[:,4]

# encode class value as integers
encoder = LabelEncoder()
encoder.fit(Y)
encoded_Y = encoder.transform(Y)
dummy_y = np_utils.to_categorical(encoded_Y)

def baseline_model():
model = Sequential()
model.add(Dense(4, input_dim=4, init='normal', activation='relu'))
model.add(Dense(3, init='normal', activation='sigmoid'))
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
return model

filepath="weights-improvement-{epoch:02d}-{val_acc:.2f}.hdf5"
checkpoint = ModelCheckpoint(filepath, monitor='val_acc', verbose=1, save_best_only=True, mode='max')
callbacks_list = [checkpoint]

estimator = KerasClassifier(build_fn=baseline_model, validation_split=0.33, nb_epoch=200, batch_size=5, callbacks=callbacks_list, verbose=0)
kfold = KFold(n_splits=10, shuffle=True, random_state=seed)
results = cross_val_score(estimator, X, dummy_y, cv=kfold)
print("Baseline: %.2f%% (%.2f%%)" % (results.mean()*100, results.std()*100))

此脚本产生以下错误。我不知道如何解决这个问题,或者我在脚本中的安排是错误的。

RuntimeError: Cannot clone object <keras.wrappers.scikit_learn.KerasClassifier object at 0x10e120fd0>, as the constructor does not seem to set parameter callbacks

我希望有人能帮助我解决这个问题。谢谢。

最佳答案

我认为问题在于您的 baseline_model() 函数没有返回它正在创建的模型;它应该是这样的:

def baseline_model():
model = Sequential()
model.add(Dense(4, input_dim=4, init='normal', activation='relu'))
model.add(Dense(3, init='normal', activation='sigmoid'))
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
return model

关于python - Keras 中的检查点深度学习模型,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/41937719/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com