
pytorch - Implementing Luong Attention in PyTorch


I am trying to implement the attention described in Luong et al. 2015 myself in PyTorch, but I couldn't get it to work. Below is my code; for now I am only interested in the "general" attention case. I wonder whether I am missing any obvious mistakes. It runs, but it doesn't seem to learn.

import torch
import torch.nn as nn
import torch.nn.functional as F


class AttnDecoderRNN(nn.Module):
    def __init__(self, hidden_size, output_size, dropout_p=0.1):
        super(AttnDecoderRNN, self).__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.dropout_p = dropout_p

        self.embedding = nn.Embedding(
            num_embeddings=self.output_size,
            embedding_dim=self.hidden_size
        )
        self.dropout = nn.Dropout(self.dropout_p)
        self.gru = nn.GRU(self.hidden_size, self.hidden_size)
        self.attn = nn.Linear(self.hidden_size, self.hidden_size)
        # hc: [hidden, context]
        self.Whc = nn.Linear(self.hidden_size * 2, self.hidden_size)
        # s: softmax
        self.Ws = nn.Linear(self.hidden_size, self.output_size)

    def forward(self, input, hidden, encoder_outputs):
        embedded = self.embedding(input).view(1, 1, -1)
        embedded = self.dropout(embedded)

        gru_out, hidden = self.gru(embedded, hidden)

        # [0] removes the (directions x layers) dimension for now
        attn_prod = torch.mm(self.attn(hidden)[0], encoder_outputs.t())
        attn_weights = F.softmax(attn_prod, dim=1)  # eq. 7/8
        context = torch.mm(attn_weights, encoder_outputs)

        # hc: [hidden; context]
        out_hc = F.tanh(self.Whc(torch.cat([hidden[0], context], dim=1)))  # eq. 5
        output = F.log_softmax(self.Ws(out_hc), dim=1)  # eq. 6

        return output, hidden, attn_weights
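
For reference, these are the equations from Luong et al. 2015 that the comments point to, written out as I read the paper (h̄_s are the encoder states, h_t the current decoder state; the "general" score is the one used here):

score(h_t, \bar{h}_s) = h_t^\top W_a \bar{h}_s                          % general score, eq. 8
a_t(s) = \mathrm{softmax}_s\big(score(h_t, \bar{h}_s)\big)              % alignment weights, eq. 7
c_t = \textstyle\sum_s a_t(s)\, \bar{h}_s                               % context vector
\tilde{h}_t = \tanh\big(W_c\, [c_t; h_t]\big)                           % attentional hidden state, eq. 5
p(y_t \mid y_{<t}, x) = \mathrm{softmax}\big(W_s\, \tilde{h}_t\big)     % predictive distribution, eq. 6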

I have looked at the attention implementations in

https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html

https://github.com/spro/practical-pytorch/blob/master/seq2seq-translation/seq2seq-translation.ipynb

  • The first one is not exactly the attention mechanism I am looking for. A major drawback is that its attention depends on the sequence length (self.attn = nn.Linear(self.hidden_size * 2, self.max_length)), which can be costly for long sequences; see the short sketch after this list.
  • The second one is closer to what the paper describes, but still not the same, since there is no tanh. It is also really slow after updating to the latest version of pytorch (ref), and I don't understand why it needs the last context (ref).
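
To make that contrast concrete, here is a minimal sketch (mine, not from the question or either tutorial; sizes and variable names are arbitrary) of the "general" score: the scoring matrix W_a is hidden_size x hidden_size, so its parameter count does not depend on the sequence length, unlike nn.Linear(hidden_size * 2, max_length).

import torch
import torch.nn as nn

hidden_size, seq_len = 8, 5                              # illustrative sizes
W_a = nn.Linear(hidden_size, hidden_size, bias=False)    # score(h_t, h_s) = h_t^T W_a h_s

h_t = torch.randn(1, hidden_size)                        # current decoder hidden state
encoder_outputs = torch.randn(seq_len, hidden_size)      # one row per source position

scores = W_a(h_t) @ encoder_outputs.t()                  # shape (1, seq_len), for any seq_len
attn_weights = torch.softmax(scores, dim=1)
context = attn_weights @ encoder_outputs                 # shape (1, hidden_size)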

Best Answer

This version works, and it follows the definition of Luong attention (general) closely. The main difference from the code in the question is the separation of embedding_size and hidden_size, which, after some experimentation, appears to matter for training. Previously I had set both to the same size (256), which hampered learning, and it looked as if the network could only learn half of the sequences.

import torch
import torch.nn as nn
import torch.nn.functional as F

# DEVICE is assumed to be defined elsewhere, e.g.:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class EncoderRNN(nn.Module):
    def __init__(self, input_size, embedding_size, hidden_size,
                 num_layers=1, bidirectional=False, batch_size=1):
        super(EncoderRNN, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        self.batch_size = batch_size

        self.embedding = nn.Embedding(input_size, embedding_size)

        self.gru = nn.GRU(embedding_size, hidden_size, num_layers,
                          bidirectional=bidirectional)

    def forward(self, input, hidden):
        embedded = self.embedding(input).view(1, 1, -1)
        output, hidden = self.gru(embedded, hidden)
        return output, hidden

    def initHidden(self):
        directions = 2 if self.bidirectional else 1
        return torch.zeros(
            self.num_layers * directions,
            self.batch_size,
            self.hidden_size,
            device=DEVICE
        )


class AttnDecoderRNN(nn.Module):
    def __init__(self, embedding_size, hidden_size, output_size, dropout_p=0):
        super(AttnDecoderRNN, self).__init__()
        self.embedding_size = embedding_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.dropout_p = dropout_p

        self.embedding = nn.Embedding(
            num_embeddings=output_size,
            embedding_dim=embedding_size
        )
        self.dropout = nn.Dropout(self.dropout_p)
        self.gru = nn.GRU(embedding_size, hidden_size)
        self.attn = nn.Linear(hidden_size, hidden_size)
        # hc: [hidden, context]
        self.Whc = nn.Linear(hidden_size * 2, hidden_size)
        # s: softmax
        self.Ws = nn.Linear(hidden_size, output_size)

    def forward(self, input, hidden, encoder_outputs):
        embedded = self.embedding(input).view(1, 1, -1)
        embedded = self.dropout(embedded)

        gru_out, hidden = self.gru(embedded, hidden)

        # "general" score: h_t^T W_a h_s for every encoder state, eq. 7/8
        attn_prod = torch.mm(self.attn(hidden)[0], encoder_outputs.t())
        attn_weights = F.softmax(attn_prod, dim=1)
        context = torch.mm(attn_weights, encoder_outputs)

        # hc: [hidden; context]
        hc = torch.cat([hidden[0], context], dim=1)
        out_hc = F.tanh(self.Whc(hc))
        output = F.log_softmax(self.Ws(out_hc), dim=1)

        return output, hidden, attn_weights
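
A minimal usage sketch, not part of the original answer: the vocabulary sizes, hidden sizes, SOS_token index, and the toy source sequence below are illustrative assumptions, and it only runs a single decode step.

# Illustrative hyperparameters, chosen arbitrarily for the sketch
INPUT_VOCAB, OUTPUT_VOCAB = 10, 12
EMBEDDING_SIZE, HIDDEN_SIZE = 16, 32
SOS_token = 0  # assumed start-of-sequence index

encoder = EncoderRNN(INPUT_VOCAB, EMBEDDING_SIZE, HIDDEN_SIZE).to(DEVICE)
decoder = AttnDecoderRNN(EMBEDDING_SIZE, HIDDEN_SIZE, OUTPUT_VOCAB).to(DEVICE)

# Encode a toy source sequence token by token, collecting the encoder outputs
src = torch.tensor([1, 4, 7, 2], device=DEVICE)
enc_hidden = encoder.initHidden()
encoder_outputs = torch.zeros(len(src), HIDDEN_SIZE, device=DEVICE)
for i, tok in enumerate(src):
    enc_out, enc_hidden = encoder(tok.view(1, 1), enc_hidden)
    encoder_outputs[i] = enc_out[0, 0]

# One decoder step attending over the encoder outputs
dec_input = torch.tensor([[SOS_token]], device=DEVICE)
log_probs, dec_hidden, attn = decoder(dec_input, enc_hidden, encoder_outputs)
print(log_probs.shape, attn.shape)  # (1, OUTPUT_VOCAB) and (1, len(src))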

Regarding pytorch - Implementing Luong Attention in PyTorch, we found a similar question on Stack Overflow: https://stackoverflow.com/questions/50571991/
