gpt4 book ai didi

python-3.x - LSTM 的预期隐藏状态维度没有考虑批量大小

转载 作者:行者123 更新时间:2023-12-04 17:40:50 26 4
gpt4 key购买 nike

我有这个解码器模型,它应该将成批的句子嵌入(batchsize = 50,hidden size = 300)作为输入并输出一批预测句子的热点表示:

class DecoderLSTMwithBatchSupport(nn.Module):
# Your code goes here
def __init__(self, embedding_size,batch_size, hidden_size, output_size):
super(DecoderLSTMwithBatchSupport, self).__init__()
self.hidden_size = hidden_size
self.batch_size = batch_size
self.lstm = nn.LSTM(input_size=embedding_size,num_layers=1, hidden_size=hidden_size, batch_first=True)
self.out = nn.Linear(hidden_size, output_size)
self.softmax = nn.LogSoftmax(dim=1)

def forward(self, my_input, hidden):
print(type(my_input), type(hidden))
output, hidden = self.lstm(my_input, hidden)
output = self.softmax(self.out(output[0]))
return output, hidden

def initHidden(self):
return Variable(torch.zeros(1, self.batch_size, self.hidden_size)).cuda()

但是,当我使用以下命令运行它时:
decoder=DecoderLSTMwithBatchSupport(vocabularySize,batch_size, 300, vocabularySize)
decoder.cuda()
decoder_input=np.zeros([batch_size,vocabularySize])
for i in range(batch_size):
decoder_input[i] = embeddings[SOS_token]
decoder_input=Variable(torch.from_numpy(decoder_input)).cuda()
decoder_hidden = (decoder.initHidden(),decoder.initHidden())
for di in range(target_length):
decoder_output, decoder_hidden = decoder(decoder_input.view(1,batch_size,-1), decoder_hidden)

我得到他以下错误:

Expected hidden[0] size (1, 1, 300), got (1, 50, 300)



为了使模型期望批量隐藏状态,我缺少什么?

最佳答案

创建 LSTM 时,国旗batch_first不是必需的,因为它假定您的输入具有不同的形状。从文档:

If True, then the input and output tensors are provided as (batch, seq, feature). Default: False



将 LSTM 创建更改为:
self.lstm = nn.LSTM(input_size=embedding_size, num_layers=1, hidden_size=hidden_size)

此外,还有一个类型错误。创建 decoder_input 时使用 torch.from_numpy()它有一个 dtype=torch.float64 , 而 decoder_input默认有 dtype=torch.float32 .更改您创建 decoder_input 的行像
decoder_input = Variable(torch.from_numpy(decoder_input)).cuda().float()

有了这两个更改,它应该可以正常工作:)

关于python-3.x - LSTM 的预期隐藏状态维度没有考虑批量大小,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/54566209/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com