
python - How do I embed a sequence of sentences in an RNN?

Reposted. Author: 行者123. Updated: 2023-12-02 02:23:49

I am trying to build an RNN model (in PyTorch) that takes several sentences and classifies them as Class 0 or Class 1.

For this question, assume each sentence has a max_len of 4 and each data point has at most 5 time steps. Every data point therefore has the following form (0 is the padding value):

x[1] = [
# Input features at timestep 1
[1, 48, 91, 0],
# Input features at timestep 2
[20, 5, 17, 32],
# Input features at timestep 3
[12, 18, 0, 0],
# Input features at timestep 4
[0, 0, 0, 0],
# Input features at timestep 5
[0, 0, 0, 0]
]
y[1] = [1]

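When I have only one sentence per target, I simply pass each word through an embedding layer and then into an LSTM or GRU. I am stuck on what to do when each target has a sequence of sentences.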

How do I build an embedding that can handle sentences?

Best answer

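The simplest approach is to use two LSTMs: one that runs over the words of each sentence, and one that runs over the resulting sentence vectors.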

Prepare a toy dataset

import torch
import torch.nn as nn

xi = [
# Input features at timestep 1
[1, 48, 91, 0],
# Input features at timestep 2
[20, 5, 17, 32],
# Input features at timestep 3
[12, 18, 0, 0],
# Input features at timestep 4
[0, 0, 0, 0],
# Input features at timestep 5
[0, 0, 0, 0]
]
yi = 1

x = torch.tensor([xi, xi])
y = torch.tensor([yi, yi])

print(x.shape)
# torch.Size([2, 5, 4])

print(y.shape)
# torch.Size([2])

Here x is a batch of inputs with batch_size = 2.
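
Because 0 is the padding value, the number of real words per sentence and the number of real sentences per sample can be read straight off the batch. The following is a minimal sketch; the names PAD_ID, word_lengths and sent_lengths are illustrative and not part of the original answer.

```python
import torch

# Toy batch from the answer: [batch, num_sentences, max_len], 0 = padding.
xi = [
    [1, 48, 91, 0],
    [20, 5, 17, 32],
    [12, 18, 0, 0],
    [0, 0, 0, 0],
    [0, 0, 0, 0],
]
x = torch.tensor([xi, xi])

PAD_ID = 0  # assumption: index 0 is reserved for padding

# Number of real tokens in each sentence, shape [batch, num_sentences]
word_lengths = (x != PAD_ID).sum(dim=-1)

# Number of non-empty sentences in each sample, shape [batch]
sent_lengths = (word_lengths > 0).sum(dim=-1)

print(word_lengths.tolist())  # [[3, 4, 2, 0, 0], [3, 4, 2, 0, 0]]
print(sent_lengths.tolist())  # [3, 3]
```

These lengths are not needed by the basic version below, but they are useful if you later want the LSTMs to skip padded positions.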

Embed the input

vocab_size = 1000
embed_size = 100
hidden_size = 200
embed = nn.Embedding(vocab_size, embed_size)

# shape [2, 5, 4, 100]
x = embed(x)
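
One optional refinement, not in the original answer: since index 0 is padding, nn.Embedding can be told so through its padding_idx argument. The padding row is then initialised to zeros and receives no gradient updates. A small sketch with a made-up input:

```python
import torch
import torch.nn as nn

vocab_size = 1000
embed_size = 100

# padding_idx=0: the vector for token 0 starts at zero and is never updated
embed = nn.Embedding(vocab_size, embed_size, padding_idx=0)

# One sample with one real sentence and one all-padding sentence: [1, 2, 4]
x = torch.tensor([[[1, 48, 91, 0], [0, 0, 0, 0]]])
e = embed(x)  # [1, 2, 4, 100]

print(e.shape)                    # torch.Size([1, 2, 4, 100])
print(bool((e[0, 1] == 0).all())) # True: padded sentence embeds to zeros
```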

The first LSTM (word-level) encodes each sentence into a vector

# Flatten the batch and sentence dimensions so that every sentence is one sequence
# Reshape [2, 5, 4, 100] into [10, 4, 100]
bs = x.size(0)
x = x.view(bs * 5, 4, embed_size)

wlstm = nn.LSTM(embed_size, hidden_size, batch_first=True)

# Keep only the final hidden state of each sentence
_, (hn, _) = wlstm(x)

# hn shape [1, 10, 200]

# Take the hidden state of the last (here the only) layer
hn = hn[0] # [10, 200]
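
Note that the version above also feeds the padding tokens through the word-level LSTM. If that matters, one possible variant (an addition here, not part of the original answer) is to pack the sequences with pack_padded_sequence so the final hidden state is taken at the last real word. Since pack_padded_sequence does not accept zero lengths, this sketch treats empty sentences as length 1 and zeroes their vectors afterwards; that workaround is an assumption, not the only option.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence

torch.manual_seed(0)
embed_size, hidden_size = 100, 200

# Four sentences of max_len 4, 0 = padding; the last one is empty
tokens = torch.tensor([
    [1, 48, 91, 0],
    [20, 5, 17, 32],
    [12, 18, 0, 0],
    [0, 0, 0, 0],
])

embed = nn.Embedding(1000, embed_size, padding_idx=0)
wlstm = nn.LSTM(embed_size, hidden_size, batch_first=True)

lengths = (tokens != 0).sum(dim=1)    # real words per sentence
safe_lengths = lengths.clamp(min=1)   # workaround: packing rejects length 0

packed = pack_padded_sequence(
    embed(tokens), safe_lengths.cpu(), batch_first=True, enforce_sorted=False
)
_, (hn, _) = wlstm(packed)            # hn: [1, 4, 200], in the original order
sent_vecs = hn[0]                     # [4, 200]

# Replace the vectors of empty sentences with zeros
keep = (lengths > 0).unsqueeze(1).to(sent_vecs.dtype)
sent_vecs = sent_vecs * keep
```

The resulting sent_vecs plays the same role as hn[0] above and can be reshaped to [bs, num_seq, hidden_size] in the same way.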

The second LSTM (sentence-level) encodes the sequence of sentence vectors into a single vector

# Reshape hn into [bs, num_seq, hidden_size]
hn = hn.view(2, 5, 200)

# Pass to another LSTM and get the final state hn
slstm = nn.LSTM(hidden_size, hidden_size, batch_first=True)
_, (hn, _) = slstm(hn) # [1, 2, 200]

# Similarly, get the hidden state of the last layer
hn = hn[0] # [2, 200]

Add a classification layer

pred_linear = nn.Linear(hidden_size, 1)

# [2, 1]
output = torch.sigmoid(pred_linear(hn))
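
The steps above can be collected into a single module. This is only a sketch under some assumptions: the class name HierarchicalClassifier is made up, the shapes are read from the input instead of being hard-coded, and the model returns raw logits so that nn.BCEWithLogitsLoss can be used for training (the sigmoid is applied separately when probabilities are needed).

```python
import torch
import torch.nn as nn


class HierarchicalClassifier(nn.Module):
    """Hypothetical wrapper around the two-LSTM idea; names are illustrative."""

    def __init__(self, vocab_size=1000, embed_size=100, hidden_size=200):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_size, padding_idx=0)
        self.wlstm = nn.LSTM(embed_size, hidden_size, batch_first=True)
        self.slstm = nn.LSTM(hidden_size, hidden_size, batch_first=True)
        self.pred_linear = nn.Linear(hidden_size, 1)

    def forward(self, x):
        bs, num_sent, max_len = x.shape
        e = self.embed(x)                             # [bs, num_sent, max_len, embed]
        e = e.reshape(bs * num_sent, max_len, -1)     # one row per sentence
        _, (hn, _) = self.wlstm(e)                    # hn: [1, bs * num_sent, hidden]
        s = hn[-1].reshape(bs, num_sent, -1)          # [bs, num_sent, hidden]
        _, (hn, _) = self.slstm(s)                    # hn: [1, bs, hidden]
        return self.pred_linear(hn[-1]).squeeze(-1)   # raw logits, [bs]


xi = [
    [1, 48, 91, 0],
    [20, 5, 17, 32],
    [12, 18, 0, 0],
    [0, 0, 0, 0],
    [0, 0, 0, 0],
]
x = torch.tensor([xi, xi])        # [2, 5, 4]
y = torch.tensor([1.0, 1.0])      # float targets for the binary loss

model = HierarchicalClassifier()
logits = model(x)                 # [2]
loss = nn.BCEWithLogitsLoss()(logits, y)
loss.backward()

probs = torch.sigmoid(logits)     # class-1 probabilities, [2]
```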

Regarding "python - How do I embed a sequence of sentences in an RNN?", a similar question can be found on Stack Overflow: https://stackoverflow.com/questions/60186944/
