gpt4 book ai didi

keras - 在 PyTorch 中准备序列到序列网络的解码器

转载 作者:行者123 更新时间:2023-12-04 02:28:50 24 4
gpt4 key购买 nike

我在 Pytorch 中使用 Sequence to Sequence 模型。序列到序列模型由编码器和解码器组成。

编码器转换 (batch_size X input_features X num_of_one_hot_encoded_classes) -> (batch_size X input_features X hidden_size)
解码器将把这个输入序列转换成 (batch_size X output_features X num_of_one_hot_encoded_classes)
一个例子就像 -

enter image description here

所以在上面的例子中,我需要将 22 个输入特征转换为 10 个输出特征。在 Keras 中,可以使用 RepeatVector(10) 来完成。

一个例子 -

model.add(LSTM(256, input_shape=(22, 98)))
model.add(RepeatVector(10))
model.add(Dropout(0.3))
model.add(LSTM(256, return_sequences=True))

虽然,我不确定这是否是将输入序列转换为输出序列的正确方法。

所以,我的问题是——
  • 将输入序列转换为的标准方法是什么
    输出的。例如。从 (batch_size, 22, 98) -> (batch_size,
    10, 98)?或者我应该如何准备解码器?

  • 编码器代码片段(用 Pytorch 编写)-
    class EncoderRNN(nn.Module):
    def __init__(self, input_size, hidden_size):
    super(EncoderRNN, self).__init__()
    self.hidden_size = hidden_size
    self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size,
    num_layers=1, batch_first=True)

    def forward(self, input):
    output, hidden = self.lstm(input)
    return output, hidden

    最佳答案

    好吧,您必须选择,第一个是将编码器的最后状态重复 10 次并将其作为解码器的输入,如下所示:

    import torch
    input = torch.randn(64, 22, 98)
    encoder = torch.nn.LSTM(98, 256, batch_first=True)
    encoded, _ = encoder(input)
    decoder_input = encoded[:, -1:].repeat(1, 10, 1)
    decoder = torch.nn.LSTM(256, 98, batch_first=True)
    decoded, _ = decoder(decoder_input)
    print(decoded.shape) #torch.Size([64, 10, 98])

    另一种选择是使用注意力机制,如下所示:
    #assuming we have obtained the encoded sequence and declared the decoder as before
    attention_calculator = torch.nn.Conv1d(256+98, 1, kernel_size=1)
    hidden = (torch.zeros(1, 64, 98), torch.zeros(1, 64, 98))
    outputs = []
    for i in range(10):
    attention_input = torch.cat([hidden[0][0][:, None, :].expand(-1, 22, -1), encoded], dim=2).permute(0, 2, 1)
    attention_value = torch.nn.functional.softmax(attention_calculator(attention_input).squeeze(), dim=1)
    decoder_input = (attention_value[:, :, None] * encoded).sum(dim=1, keepdim=True)
    output, hidden = decoder(decoder_input, hidden)
    outputs.append(output)
    outputs = torch.cat(outputs, dim=1)

    关于keras - 在 PyTorch 中准备序列到序列网络的解码器,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/52437708/

    24 4 0
    Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
    广告合作:1813099741@qq.com 6ren.com