gpt4 book ai didi

java - Encog - 使用交叉验证方法训练神经网络并验证训练

转载 作者:塔克拉玛干 更新时间:2023-11-02 19:53:42 24 4
gpt4 key购买 nike

一旦我看到从验证集计算出的误差开始增加,我就想停止训练网络。我正在使用带有 RPROP 的 BasicNetwork 作为训练算法,并且我有以下训练迭代:

void trainCrossValidation(BasicNetwork network, MLDataSet training, MLDataSet validation) {

FoldedDataSet folded = new FoldedDataSet(training);
Train train = new ResilientPropagation(network, folded);
CrossValidationKFold trainFolded = new CrossValidationKFold(train, KFOLDS);
trainFolded.addStrategy(new SimpleEarlyStoppingStrategy(validation));


int epoch = 1;
do {
trainFolded.iteration();
logger.debug("Iter. " + epoch + ": Erro = " + trainFolded.getError());
epoch++;
} while (!trainFolded.isTrainingDone() && epoch < MAX_ITERATIONS);

trainFolded.finishTraining();
}

不幸的是,它没有按预期工作。该方法需要很长时间才能执行,而且似乎不会在正确的时刻停止。我希望训练在验证错误开始增长的那一刻就中止,也就是说,在理想的训练迭代次数中。

有没有一种方法可以直接从折叠的交叉验证中提取验证数据,而不是创建另一个专门用于验证的MLDataSet?如果是,该怎么做?

我应该使用哪个参数来停止训练?你能告诉我必要的修改来做预期的事情吗?我怎样才能同时使用交叉验证和 SimpleEarlyStoppingStrategy?我很困惑

非常感谢您的帮助。

最佳答案

我认为那里有几个混淆点。

  • 一件事是当错误在不同的(!)数据集中开始增加时停止训练。这组数据通常称为验证数据集。这对每次训练都是有效的,即对于达到最大纪元数的每个循环(您的 do-while 循环)。为此,您需要跟踪上一次迭代的测量误差:在验证数据集上运行网络,获取误差并与下一个时期的误差进行比较。

    <
  • 另一件事是交叉验证。在这里,您多次训练网络。从整个训练过程中,您估计网络的优度。这是一种更复杂、更稳健的方法,具有不同的变体。我喜欢this diagram .

  • 最后:您在误差开始增加时停止训练并不意味着您找到了理想的训练量 时期。您可能会陷入局部最小值,这是这些模型中的常见问题。

希望对你有所帮助:)

关于java - Encog - 使用交叉验证方法训练神经网络并验证训练,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/37470649/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com