gpt4 book ai didi

deep-learning - PyTorch 中带有 Sequential 模块的简单 LSTM

转载 作者:行者123 更新时间:2023-12-04 21:06:43 24 4
gpt4 key购买 nike

在 PyTorch 中,我们可以通过多种方式定义架构。在这里,我想使用 Sequential 创建一个简单的 LSTM 网络。模块。

在 Lua 的火炬中,我通常会选择:

model = nn.Sequential()
model:add(nn.SplitTable(1,2))
model:add(nn.Sequencer(nn.LSTM(inputSize, hiddenSize)))
model:add(nn.SelectTable(-1)) -- last step of output sequence
model:add(nn.Linear(hiddenSize, classes_n))

但是,在 PyTorch 中,我找不到 SelectTable 的等价物。获得最后的输出。
nn.Sequential(
nn.LSTM(inputSize, hiddenSize, 1, batch_first=True),
# what to put here to retrieve last output of LSTM ?,
nn.Linear(hiddenSize, classe_n))

最佳答案

首先,我让 i 类提取最后一个单元格输出,如下所示

class extractlastcell(nn.Module):
def forward(self,x):
out , _ = x
return out[:, -1, :]
当我想在你的例子中使用它时,它会是这样的
nn.Sequential(
nn.LSTM(inputSize, hiddenSize, 1, batch_first=True),
extractlastcell(),
nn.Linear(hiddenSize, classe_n))

关于deep-learning - PyTorch 中带有 Sequential 模块的简单 LSTM,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/44130851/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com