gpt4 book ai didi

pytorch - 如何在 PyTorch 中正确实现批量输入 LSTM 网络?

转载 作者:行者123 更新时间:2023-12-03 22:35:26 25 4
gpt4 key购买 nike

这个release PyTorch 似乎提供了 PackedSequence用于循环神经网络的可变长度输入。但是,我发现正确使用它有点困难。

使用 pad_packed_sequence恢复由 pack_padded_sequence 提供的 RNN 层的输出, 我们得到了 T x B x N张量 outputs在哪里 T是最大时间步长,B是批量大小和 N是隐藏的大小。我发现对于批处理中的短序列,后续输出将全为零。

这是我的问题。

  • 对于需要所有序列的最后一个输出的单个输出任务,简单的 outputs[-1]会给出错误的结果,因为这个张量包含很多短序列的零。需要按序列长度构建索引以获取所有序列的单个最后输出。有没有更简单的方法来做到这一点?
  • 对于多输出任务(例如 seq2seq),通常会添加一个线性层 N x O并 reshape 批量输出 T x B x O进入 TB x O并计算与真实目标的交叉熵损失 TB (通常是语言模型中的整数)。在这种情况下,批量输出中的这些零重要吗?
  • 最佳答案

    问题 1 - 最后时间步长

    这是我用来获取最后一个时间步的输出的代码。我不知道是否有更简单的解决方案。如果是,我想知道。我关注了这个 discussion并为我的 last_timestep 获取了相关代码片段方法。这是我的前锋。

    class BaselineRNN(nn.Module):
    def __init__(self, **kwargs):
    ...

    def last_timestep(self, unpacked, lengths):
    # Index of the last output for each sequence.
    idx = (lengths - 1).view(-1, 1).expand(unpacked.size(0),
    unpacked.size(2)).unsqueeze(1)
    return unpacked.gather(1, idx).squeeze()

    def forward(self, x, lengths):
    embs = self.embedding(x)

    # pack the batch
    packed = pack_padded_sequence(embs, list(lengths.data),
    batch_first=True)

    out_packed, (h, c) = self.rnn(packed)

    out_unpacked, _ = pad_packed_sequence(out_packed, batch_first=True)

    # get the outputs from the last *non-masked* timestep for each sentence
    last_outputs = self.last_timestep(out_unpacked, lengths)

    # project to the classes using a linear layer
    logits = self.linear(last_outputs)

    return logits

    问题 2 - 掩蔽交叉熵损失

    是的,默认情况下,零填充时间步(目标)很重要。但是,很容易掩盖它们。您有两个选择,具体取决于您使用的 PyTorch 版本。
  • PyTorch 0.2.0 : 现在 pytorch 支持直接在 CrossEntropyLoss 中屏蔽, 与 ignore_index争论。例如,在语言建模或 seq2seq 中,我添加零填充,我像这样简单地屏蔽零填充词(目标):

    loss_function = nn.CrossEntropyLoss(ignore_index=0)
  • PyTorch 0.1.12和更早版本:在旧版本的 PyTorch 中,不支持屏蔽,因此您必须实现自己的解决方法。我使用的解决方案是 masked_cross_entropy.py , 由 jihunchoi .您可能对此也感兴趣 discussion .
  • 关于pytorch - 如何在 PyTorch 中正确实现批量输入 LSTM 网络?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/46387661/

    25 4 0
    Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
    广告合作:1813099741@qq.com 6ren.com