gpt4 book ai didi

python - ScikitLearn 中关于 epoch 的学习曲线解读

转载 作者:行者123 更新时间:2023-11-30 09:44:25 28 4
gpt4 key购买 nike

我是机器学习新手,目前正在使用 ScikitLearn 的 MLPClassifier 来执行神经网络任务。根据 Andrew Ng 著名的机器学习类(class),我正在绘制学习曲线,在我的例子中,使用 ScikitLearn 的函数learning_curve(另请参阅文档:https://scikit-learn.org/stable/auto_examples/model_selection/plot_learning_curve.html):

clf = MLPClassifier(solver='adam', activation='relu', alpha=0.001,
learning_rate='constant',learning_rate_init=0.0001,
hidden_layer_sizes=[39, 37, 31, 34], batch_size=200,
max_iter=1000, verbose=True)


cv=GroupKFold(n_splits=8)

estimator =clf
ylim=(0.7, 1.01)
cv=cv
n_jobs=1
train_sizes=np.linspace(.01, 1.0, 100)


#Calculate learning curve
train_sizes, train_scores, test_scores = learning_curve(
estimator, X_array_train, Y_array_train,
groups=groups_array_train, cv=cv, n_jobs=n_jobs,
train_sizes=train_sizes, scoring='accuracy',verbose=10)

我的 MLPClassifier 求解器是“adam”,批量大小为 200。

这是结果图: https://i.imgur.com/jDNoEVg.png

关于此类学习曲线的解释,我有两个问题:

1.) 据我了解这条学习曲线,它为我提供了不同数量的训练数据的训练和交叉验证分数,直到一个时期结束(时期=一次前向传递和一次反向传递)所有训练示例)。看看这两者之间的“差距”以及它们最终的得分,我可以诊断是否存在高偏差或方差问题。然而,根据我的 MLPClassifier 的详细信息,神经网络正在多个时期进行训练,因此曲线中给出了哪个时期(训练的第一个时期,最后一个时期还是所有时期的平均分数?) 。或者我对时代有什么误解?

2.) 开始一个新批处理(在 200 和 400 个训练示例之后),我得到了峰值。解释它们的正确方法是什么?

3.) 可能理解 1.) 也会回答这个问题:是什么让这个函数如此缓慢,以至于你需要几个并行作业 n_jobs 才能在合理的时间内完成它? clf.fit(X,y) 在我的情况下很快。

如果有人能帮助我更好地理解这一点,我将非常感激。我也愿意接受文献推荐。

非常感谢!

最佳答案

学习曲线只能在稳定的、可推广的模型上计算。您确保模型不会过度拟合吗?

1) 估计器被训练至完成,即训练至最终时期或任何早期停止阈值。这是多少取决于您的估算器配置。事实上,learning_curve 函数根本没有纪元的概念。它也可以应用于不使用纪元的分类器。

2) 与总样本数相比,您的批量大小非常大。考虑较小的批量大小,可能是 50 或 20。猜测:可能对于 201 个 sample ,您最终会得到一批 200 个和一批 1。那批 1 可能会导致问题。

3) 学习曲线将为每个训练样本选择的每个交叉验证折叠进行训练。就您而言,您似乎正在测试所有 500 种可能的训练规模。如果 CV 为 5 倍,则将进行 2500 轮训练。如果没有并行化,这需要 1 fit()+predict() 的 2500 倍。相反,您应该只对一些训练集大小进行采样。 train_sizes = numpy.linspace(0.0, 1.0, 30) 用于数据 0% 到 100% 之间的 30 个点。

关于python - ScikitLearn 中关于 epoch 的学习曲线解读,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/54327491/

28 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com