gpt4 book ai didi

python - 从 PyTorch 中的 BiLSTM (BiGRU) 获取最后一个状态

转载 作者:行者123 更新时间:2023-12-04 05:57:40 45 4
gpt4 key购买 nike

在阅读了几篇文章后,我仍然对我从 BiLSTM 获取最后隐藏状态的实现的正确性感到困惑。

  • Understanding Bidirectional RNN in PyTorch (TowardsDataScience)
  • PackedSequence for seq2seq model (PyTorch forums)
  • What's the difference between “hidden” and “output” in PyTorch LSTM? (StackOverflow)
  • Select tensor in a batch of sequences (Pytorch formums)

  • 最后一个来源(4)的方法对我来说似乎是最干净的,但我仍然不确定我是否正确理解了该线程。我是否使用了来自 LSTM 和反向 LSTM 的正确最终隐藏状态?这是我的实现
    # pos contains indices of words in embedding matrix
    # seqlengths contains info about sequence lengths
    # so for instance, if batch_size is 2 and pos=[4,6,9,3,1] and
    # seqlengths contains [3,2], we have batch with samples
    # of variable length [4,6,9] and [3,1]

    all_in_embs = self.in_embeddings(pos)
    in_emb_seqs = pack_sequence(torch.split(all_in_embs, seqlengths, dim=0))
    output,lasthidden = self.rnn(in_emb_seqs)
    if not self.data_processor.use_gru:
    lasthidden = lasthidden[0]
    # u_emb_batch has shape batch_size x embedding_dimension
    # sum last state from forward and backward direction
    u_emb_batch = lasthidden[-1,:,:] + lasthidden[-2,:,:]

    这是正确的吗?

    最佳答案

    在一般情况下,如果您想创建自己的 BiLSTM 网络,您需要创建两个常规 LSTM,并使用常规输入序列馈送一个,另一个使用反向输入序列馈送。在完成两个序列的输入后,您只需从两个网络中获取最后一个状态并以某种方式将它们联系在一起(求和或连接)。

    据我了解,您正在使用 this example 中的内置 BiLSTM(在 nn.LSTM 构造函数中设置 bidirectional=True)。然后您在输入批次后获得连接的输出,因为 PyTorch 会为您处理所有麻烦。

    如果是这种情况,并且您想对隐藏状态求和,那么您必须

    u_emb_batch = (lasthidden[0, :, :] + lasthidden[1, :, :])

    假设你只有一层。如果你有更多层,你的变体看起来更好。

    这是因为结果是结构化的(参见 documentation ):

    h_n of shape (num_layers * num_directions, batch, hidden_size): tensor containing the hidden state for t = seq_len



    顺便一提,
    u_emb_batch_2 = output[-1, :, :HIDDEN_DIM] + output[-1, :, HIDDEN_DIM:]

    应该提供相同的结果。

    关于python - 从 PyTorch 中的 BiLSTM (BiGRU) 获取最后一个状态,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/50856936/

    45 4 0
    Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
    广告合作:1813099741@qq.com 6ren.com