gpt4 book ai didi

java - RL4J A3C DeepLearning 从网络中抛出的输出不是概率分布

转载 作者:行者123 更新时间:2023-12-02 00:59:33 40 4
gpt4 key购买 nike

所以现在我正在痛苦地探索使用 Deep Learning 4j 特别是 RL4j 和强化学习的深度学习。我在教我的电脑如何玩贪吃蛇方面相对不成功,但我坚持了下来。

无论如何,我遇到了一个无法解决的问题,我会将程序设置为在我 sleep 或工作时运行(是的,我在一个重要行业工作),当我回来检查时在所有正在运行的线程上抛出此错误,并且程序已完全停止,请注意,这通常会在训练后大约一个小时内发生。

Exception in thread "Thread-8" java.lang.RuntimeException: Output from network is not a probability distribution: [[         ?,         ?,         ?]]
at org.deeplearning4j.rl4j.policy.ACPolicy.nextAction(ACPolicy.java:82)
at org.deeplearning4j.rl4j.policy.ACPolicy.nextAction(ACPolicy.java:37)
at org.deeplearning4j.rl4j.learning.async.AsyncThreadDiscrete.trainSubEpoch(AsyncThreadDiscrete.java:96)
at org.deeplearning4j.rl4j.learning.async.AsyncThread.handleTraining(AsyncThread.java:144)
at org.deeplearning4j.rl4j.learning.async.AsyncThread.run(AsyncThread.java:121)

这是我如何设置网络

    private static A3CDiscrete.A3CConfiguration CARTPOLE_A3C =
new A3CDiscrete.A3CConfiguration(
(new java.util.Random()).nextInt(), //Random seed
220, //Max step By epoch
500000, //Max step
6, //Number of threads
50, //t_max
75, //num step noop warmup
0.1, //reward scaling
0.987, //gamma
1.0 //td-error clipping
);


private static final ActorCriticFactorySeparateStdDense.Configuration CARTPOLE_NET_A3C = ActorCriticFactorySeparateStdDense.Configuration
.builder().updater(new Adam(.005)).l2(.01).numHiddenNodes(32).numLayer(3).build();

此外,我的网络的输入是将我的贪吃蛇游戏 16x16 的整个网格放入单个双数组中。

万一它与我的奖励函数有关

if(!snake.inGame()) {
return -5.3; //snake dies
}
if(snake.gotApple()) {
return 5.0+.37*(snake.getLength()); //snake gets apple
}
return 0; //survives

我的问题是如何阻止此错误的发生?我真的不知道发生了什么,这让我的网络 build 变得相当困难,是的,我已经在网上检查过答案,所有的结果就像 2018 年的 2 张 GitHub 票一样。

如果您感兴趣,那么您不必去挖掘这里是 ACPolicy 中抛出错误的函数

 public Integer nextAction(INDArray input) {
INDArray output = actorCritic.outputAll(input)[1];
if (rnd == null) {
return Learning.getMaxAction(output);
}
float rVal = rnd.nextFloat();
for (int i = 0; i < output.length(); i++) {
//System.out.println(i + " " + rVal + " " + output.getFloat(i));
if (rVal < output.getFloat(i)) {
return i;
} else
rVal -= output.getFloat(i);
}

throw new RuntimeException("Output from network is not a probability distribution: " + output);
}

非常感谢您提供的任何帮助

最佳答案

您所看到的是您的网络正在运行 NaN。这就是异常(exception)中问号的含义。发生这种情况的原因有很多。您说,您已经运行了相当长一段时间,因此可能会在某些时候出现不足或溢出。一些正则化或一些梯度裁剪可能会有所帮助。

但是,从 beta6 开始,RL4J 本身正在重新设计,并且在下一个版本中应该处于更好的状态。

如果您想尝试当前状态,可以使用快照,并且在 https://github.com/RobAltena/cartpole/blob/master/src/main/java/A3CCartpole.java 处还有一个有效的 A3C 示例。

要获得更彻底的帮助,您可能应该查看 DL4J 社区论坛:community.konduit.ai。它更适合帮助您为贪吃蛇游戏构建成功的 AI 所需的来回操作。

关于java - RL4J A3C DeepLearning 从网络中抛出的输出不是概率分布,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/60869267/

40 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com