gpt4 book ai didi

java - 如何在 Encog 中暂停/序列化遗传算法?

转载 作者:搜寻专家 更新时间:2023-10-31 20:10:49 24 4
gpt4 key购买 nike

如何在 Encog 3.4(Github 中当前正在开发的版本)中暂停遗传算法?

我正在使用 Encog 的 Java 版本。

我正在尝试修改 Encog 附带的 Lunar 示例。我想暂停/序列化遗传算法,然后在稍后阶段继续/反序列化。

当我调用 train.pause(); 时,它只返回 null - 这从代码中很明显,因为该方法总是返回 null

我认为这会非常简单,因为在某些情况下我想训练一个神经网络,用它进行一些预测,然后在我获得更多数据后继续使用遗传算法进行训练,然后再继续使用更多预测 - 无需从头开始重新开始训练。

请注意,我不是要序列化或持久化神经网络,而是整个遗传算法。

最佳答案

并非 Encog 中的所有训练器都支持简单的暂停/恢复。如果他们不支持它,他们将返回 null,就像这个一样。遗传算法训练器比支持暂停/恢复的简单传播训练器复杂得多。要保存遗传算法的状态,您必须保存整个种群,以及评分函数(可以序列化也可以不序列化)。我修改了 Lunar Lander 示例,以向您展示如何保存/重新加载您的神经网络群体来执行此操作。

您可以看到它训练了 50 次迭代,然后往返(加载/保存)遗传算法,然后再训练 50 次。

package org.encog.examples.neural.lunar;

import java.io.File;
import java.io.IOException;

import org.encog.Encog;
import org.encog.engine.network.activation.ActivationTANH;
import org.encog.ml.MLMethod;
import org.encog.ml.MLResettable;
import org.encog.ml.MethodFactory;
import org.encog.ml.ea.population.Population;
import org.encog.ml.genetic.MLMethodGeneticAlgorithm;
import org.encog.ml.genetic.MLMethodGenomeFactory;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.pattern.FeedForwardPattern;
import org.encog.util.obj.SerializeObject;

public class LunarLander {

public static BasicNetwork createNetwork()
{
FeedForwardPattern pattern = new FeedForwardPattern();
pattern.setInputNeurons(3);
pattern.addHiddenLayer(50);
pattern.setOutputNeurons(1);
pattern.setActivationFunction(new ActivationTANH());
BasicNetwork network = (BasicNetwork)pattern.generate();
network.reset();
return network;
}

public static void saveMLMethodGeneticAlgorithm(String file, MLMethodGeneticAlgorithm ga ) throws IOException
{
ga.getGenetic().getPopulation().setGenomeFactory(null);
SerializeObject.save(new File(file),ga.getGenetic().getPopulation());
}

public static MLMethodGeneticAlgorithm loadMLMethodGeneticAlgorithm(String filename) throws ClassNotFoundException, IOException {
Population pop = (Population) SerializeObject.load(new File(filename));
pop.setGenomeFactory(new MLMethodGenomeFactory(new MethodFactory(){
@Override
public MLMethod factor() {
final BasicNetwork result = createNetwork();
((MLResettable)result).reset();
return result;
}},pop));

MLMethodGeneticAlgorithm result = new MLMethodGeneticAlgorithm(new MethodFactory(){
@Override
public MLMethod factor() {
return createNetwork();
}},new PilotScore(),1);

result.getGenetic().setPopulation(pop);

return result;
}


public static void main(String args[])
{
BasicNetwork network = createNetwork();

MLMethodGeneticAlgorithm train;


train = new MLMethodGeneticAlgorithm(new MethodFactory(){
@Override
public MLMethod factor() {
final BasicNetwork result = createNetwork();
((MLResettable)result).reset();
return result;
}},new PilotScore(),500);

try {
int epoch = 1;

for(int i=0;i<50;i++) {
train.iteration();
System.out
.println("Epoch #" + epoch + " Score:" + train.getError());
epoch++;
}
train.finishTraining();

// Round trip the GA and then train again
LunarLander.saveMLMethodGeneticAlgorithm("/Users/jeff/projects/trainer.bin",train);
train = LunarLander.loadMLMethodGeneticAlgorithm("/Users/jeff/projects/trainer.bin");

// Train again
for(int i=0;i<50;i++) {
train.iteration();
System.out
.println("Epoch #" + epoch + " Score:" + train.getError());
epoch++;
}
train.finishTraining();

} catch(IOException ex) {
ex.printStackTrace();
} catch (ClassNotFoundException e) {
// TODO Auto-generated catch block
e.printStackTrace();
}

int epoch = 1;

for(int i=0;i<50;i++) {
train.iteration();
System.out
.println("Epoch #" + epoch + " Score:" + train.getError());
epoch++;
}
train.finishTraining();

System.out.println("\nHow the winning network landed:");
network = (BasicNetwork)train.getMethod();
NeuralPilot pilot = new NeuralPilot(network,true);
System.out.println(pilot.scorePilot());
Encog.getInstance().shutdown();
}
}

关于java - 如何在 Encog 中暂停/序列化遗传算法?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/29977979/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com