
python - How to do return_sequences for a stacked LSTM model with PyTorch?

Reposted. Author: 行者123. Updated: 2023-12-04 02:38:42

I have a TensorFlow/Keras model:


self.model.add(Bidirectional(LSTM(lstm1_size, input_shape=(
seq_length, feature_dim), return_sequences=True)))
self.model.add(BatchNormalization())
self.model.add(Dropout(0.2))

self.model.add(Bidirectional(
LSTM(lstm2_size, return_sequences=True)))
self.model.add(BatchNormalization())
self.model.add(Dropout(0.2))

# BOTTLENECK HERE

self.model.add(Bidirectional(
LSTM(lstm3_size, return_sequences=True)))
self.model.add(BatchNormalization())
self.model.add(Dropout(0.2))

self.model.add(Bidirectional(
LSTM(lstm4_size, return_sequences=True)))
self.model.add(BatchNormalization())
self.model.add(Dropout(0.2))

self.model.add(Bidirectional(
LSTM(lstm5_size, return_sequences=True)))
self.model.add(BatchNormalization())
self.model.add(Dropout(0.2))

self.model.add(Dense(feature_dim, activation='linear'))

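How can I build a stacked PyTorch model that behaves like return_sequences? My understanding of return_sequences is that it returns the "output" of each LSTM layer, which is then fed into the next layer.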

How can I do this with PyTorch?

Best answer

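PyTorch always returns sequences. nn.LSTM returns a tuple (output, (h_n, c_n)), where output already holds the last layer's hidden state at every time step, so there is no return_sequences flag to set; you simply pass output to the next layer.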

https://pytorch.org/docs/stable/nn.html#lstm
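To connect this to the Keras flag, here is a minimal sketch assuming a single-layer, unidirectional LSTM (the sizes and variable names are illustrative only). The full output plays the role of return_sequences=True, and slicing out the last time step gives the equivalent of return_sequences=False.

```python
import torch
from torch import nn

torch.manual_seed(0)

batch_size, time_steps, features, hidden = 2, 10, 2, 3
data = torch.randn(batch_size, time_steps, features)

# Single-layer, single-direction LSTM; batch_first=True means (batch, time, features).
lstm = nn.LSTM(input_size=features, hidden_size=hidden, batch_first=True)
output, (h_n, c_n) = lstm(data)

# Like Keras return_sequences=True: one hidden vector per time step.
sequences = output              # shape (2, 10, 3)

# Like Keras return_sequences=False: only the hidden state at the last time step.
last_step = output[:, -1, :]    # shape (2, 3)

print(sequences.shape, last_step.shape)
# For a unidirectional, single-layer LSTM this equals the final hidden state h_n.
print(torch.allclose(last_step, h_n[-1]))
```

With a bidirectional LSTM the last time step is not enough on its own, because the backward direction finishes at the first time step.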


Example:

import torch as t

batch_size = 2
time_steps = 10
features = 2
data = t.empty(batch_size, time_steps, features).normal_()

lstm = t.nn.LSTM(input_size=2, hidden_size=3, bidirectional=True, batch_first=True)

output, (h_n, c_n) = lstm(data)
print([output.shape, h_n.shape, c_n.shape])

[torch.Size([2, 10, 6]), torch.Size([2, 2, 3]), torch.Size([2, 2, 3])]
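The 6 in the output shape is 2 directions × 3 hidden units, while h_n keeps the two directions on its first axis. The following sketch (same toy sizes as above, chosen only for illustration) checks where each direction's final state sits inside output, and builds a rough analogue of Keras Bidirectional(LSTM(..., return_sequences=False)) with its default merge_mode='concat'.

```python
import torch
from torch import nn

torch.manual_seed(0)

hidden = 3
data = torch.randn(2, 10, 2)  # (batch, time_steps, features)

lstm = nn.LSTM(input_size=2, hidden_size=hidden, bidirectional=True, batch_first=True)
output, (h_n, c_n) = lstm(data)

# The last axis of output is [forward | backward], each of size hidden.
forward_seq = output[:, :, :hidden]    # shape (2, 10, 3)
backward_seq = output[:, :, hidden:]   # shape (2, 10, 3)

# The forward direction ends at the last time step...
print(torch.allclose(forward_seq[:, -1], h_n[0]))
# ...while the backward direction ends at the first time step.
print(torch.allclose(backward_seq[:, 0], h_n[1]))

# Both final states side by side, similar to Keras' concatenated last outputs.
last_states = torch.cat([h_n[0], h_n[1]], dim=-1)  # shape (2, 6)
print(last_states.shape)
```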

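class Net(t.nn.Module):
    def __init__(self):
        super().__init__()
        # First layer: 2 input features -> 3 hidden units per direction (6 output features in total)
        self.lstm_1 = t.nn.LSTM(input_size=2, hidden_size=3, bidirectional=True, batch_first=True)
        # Second layer consumes the first layer's full output sequence, so input_size = 2 * 3
        self.lstm_2 = t.nn.LSTM(input_size=2*3, hidden_size=4, bidirectional=True, batch_first=True)

    def forward(self, x):
        # output has shape (batch, time_steps, 2 * hidden_size): every time step is kept
        output, (h_n, c_n) = self.lstm_1(x)
        output, (h_n, c_n) = self.lstm_2(output)
        return output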

net = Net()

print(net(data).shape)

torch.Size([2, 10, 8])
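Applying the same idea to the whole Keras model from the question gives something like the sketch below. It is a hedged translation, not a verified port: the layer sizes are hypothetical stand-ins for lstm1_size ... lstm5_size, and the BatchNorm1d settings are an assumption chosen to mirror Keras' defaults (epsilon=1e-3, momentum=0.99, which corresponds to PyTorch momentum=0.01). Keras BatchNormalization normalizes the last (feature) axis, whereas BatchNorm1d expects (batch, channels, time), hence the transposes; Keras Dense on a 3-D tensor acts on the last axis, just like nn.Linear.

```python
from typing import Sequence

import torch
from torch import nn


class BiLSTMBlock(nn.Module):
    """Bidirectional(LSTM(return_sequences=True)) -> BatchNormalization -> Dropout."""

    def __init__(self, input_size: int, hidden_size: int, dropout: float = 0.2):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, bidirectional=True, batch_first=True)
        # Assumed to mirror Keras defaults: epsilon=1e-3, momentum=0.99 (PyTorch: 1 - 0.99).
        self.norm = nn.BatchNorm1d(2 * hidden_size, eps=1e-3, momentum=0.01)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out, _ = self.lstm(x)                                 # (batch, time, 2 * hidden_size)
        out = self.norm(out.transpose(1, 2)).transpose(1, 2)  # normalize each feature channel
        return self.dropout(out)


class StackedBiLSTM(nn.Module):
    def __init__(self, feature_dim: int, lstm_sizes: Sequence[int], dropout: float = 0.2):
        super().__init__()
        blocks = []
        in_size = feature_dim
        for size in lstm_sizes:
            blocks.append(BiLSTMBlock(in_size, size, dropout))
            in_size = 2 * size  # the next LSTM sees both directions concatenated
        self.blocks = nn.Sequential(*blocks)
        # Per-time-step projection back to feature_dim, like Dense(feature_dim, activation='linear').
        self.head = nn.Linear(in_size, feature_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.head(self.blocks(x))


# Hypothetical sizes for illustration only.
feature_dim, seq_length = 4, 10
model = StackedBiLSTM(feature_dim, [64, 32, 16, 32, 64])
y = model(torch.randn(8, seq_length, feature_dim))
print(y.shape)  # torch.Size([8, 10, 4])
```

Remember to call model.train() while training and model.eval() for inference, since BatchNorm1d and Dropout behave differently in the two modes. If you do not need normalization between layers and every layer can share one hidden size, nn.LSTM(..., num_layers=5, dropout=0.2, bidirectional=True) stacks the layers internally instead; each layer there also receives the previous layer's full output sequence.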

Regarding "python - How to do return_sequences for a stacked LSTM model with PyTorch?", a similar question was found on Stack Overflow: https://stackoverflow.com/questions/60423030/
