gpt4 book ai didi

pytorch - 为什么多层感知器在 CartPole 中优于 RNN?

转载 作者:行者123 更新时间:2023-12-04 18:33:38 24 4
gpt4 key购买 nike

最近,我比较了 CartPole-v0 环境中 DQN 的两个模型。其中一个是具有 3 层的多层感知器,另一个是由 LSTM 和 1 个全连接层构建的 RNN。我有一个大小为 200000 的经验回放缓冲区,并且在填满之前不会开始训练。尽管 MLP 在合理的训练步数下解决了问题(这意味着在最后 100 集内获得 195 的平均奖励),但 RNN 模型无法快速收敛,其最大平均奖励甚至没有达到 195!

我已经尝试增加批量大小,向 LSTM 的隐藏状态添加更多神经元,增加 RNN 的序列长度并使全连接层更加复杂 - 但每次尝试都失败了,因为我看到平均奖励的巨大波动,所以模型几乎没有收敛。这些可能是早期过度拟合的征兆吗?

class DQN(nn.Module):
def __init__(self, n_input, output_size, n_hidden, n_layers, dropout=0.3):
super(DQN, self).__init__()

self.n_layers = n_layers
self.n_hidden = n_hidden

self.lstm = nn.LSTM(input_size=n_input,
hidden_size=n_hidden,
num_layers=n_layers,
dropout=dropout,
batch_first=True)

self.dropout= nn.Dropout(dropout)

self.fully_connected = nn.Linear(n_hidden, output_size)

def forward(self, x, hidden_parameters):
batch_size = x.size(0)

output, hidden_state = self.lstm(x.float(), hidden_parameters)

seq_length = output.shape[1]

output1 = output.contiguous().view(-1, self.n_hidden)
output2 = self.dropout(output1)
output3 = self.fully_connected(output2)

new = output3.view(batch_size, seq_length, -1)
new = new[:, -1]

return new.float(), hidden_state

def init_hidden(self, batch_size, device):
weight = next(self.parameters()).data

hidden = (weight.new(self.n_layers, batch_size, self.n_hidden).zero_().to(device),
weight.new(self.n_layers, batch_size, self.n_hidden).zero_().to(device))

return hidden

与我的预期相反,更简单的模型给出了比另一个更好的结果;尽管 RNN 理应在处理时间序列数据方面表现更好。

谁能告诉我这是什么原因?

此外,我必须声明我没有应用任何特征工程,并且两个 DQN 都使用原始数据。 RNN 能否在使用归一化特征方面胜过 MLP? (我的意思是为两个模型提供标准化数据)

你有什么可以推荐给我的,以提高 RNN 的训练效率以达到最佳效果吗?

最佳答案

Contrary to what I expected the simpler model gave much better result that the other; even though RNN's supposed to be better in processing time series data.

车杆中没有时间序列,状态包含了最优决策所需的所有信息。例如,如果您要从图像中学习并且需要从一系列图像中估计极速度,情况就会有所不同。

另外,并不是说越复杂的模型就应该表现得越好。反之,更容易过拟合。对于车杆,你甚至不需要神经网络,一个简单的线性逼近器与 RBF 或随机傅立叶特征就足够了。对于这样一个简单的问题,RNN + LSTM 肯定是矫枉过正了。

关于pytorch - 为什么多层感知器在 CartPole 中优于 RNN?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/56330269/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com