python - One-hot encoding in pytorch/torchtext

Reposted. Author: 行者123. Updated: 2023-12-05 07:21:21

I have a BucketIterator from torchtext that I feed to a model in pytorch. Here is an example of how I construct the iterator:

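```python
from torchtext.data import BucketIterator  # moved to torchtext.legacy.data in torchtext 0.9

train_iter, val_iter = BucketIterator.splits((train, val),
                                             batch_size=batch_size,
                                             sort_within_batch=True,
                                             device=device,
                                             shuffle=True,
                                             sort_key=lambda x: (len(x.src), len(x.trg)))
```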
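For context on what sort_key does: bucketing groups examples of similar length into the same batch so that little padding is needed. The following is a simplified, pure-Python sketch of that idea only, not torchtext's implementation (which is more involved); the function name bucket_batches and the toy data are hypothetical.

```python
import random


def bucket_batches(examples, batch_size, sort_key, shuffle=True, seed=None):
    """Hypothetical, simplified length bucketing: sort all examples by
    sort_key, cut the sorted list into batches, then shuffle the order of
    the batches (the contents of each batch stay length-homogeneous)."""
    ordered = sorted(examples, key=sort_key)
    batches = [ordered[i:i + batch_size] for i in range(0, len(ordered), batch_size)]
    if shuffle:
        random.Random(seed).shuffle(batches)
    return batches


# Toy (src, trg) pairs of token lists, keyed by (len(src), len(trg)) as in the question
pairs = [(["a"] * n, ["b"] * (n + 1)) for n in [5, 1, 3, 2, 4, 6]]
for batch in bucket_batches(pairs, 2, sort_key=lambda x: (len(x[0]), len(x[1])), seed=0):
    print([len(src) for src, _ in batch])
```

Each printed batch holds neighbouring lengths ([1, 2], [3, 4] or [5, 6]); only the order of the batches is random.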

I then feed the data to a model like the following, which uses an nn.Embedding layer:

```python
import torch.nn as nn


class encoder(nn.Module):
    def __init__(self, input_dim, emb_dim, hid_dim, n_layers, dropout):
        super().__init__()

        self.input_dim = input_dim
        self.emb_dim = emb_dim
        self.hid_dim = hid_dim
        self.n_layers = n_layers

        self.embedding = nn.Embedding(input_dim, emb_dim)

        self.rnn = nn.LSTM(emb_dim, hid_dim, n_layers, dropout=dropout)

        self.dropout = nn.Dropout(dropout)

    def forward(self, src):
        # src = [src sent len, batch size]

        embedded = self.dropout(self.embedding(src))

        # embedded = [src sent len, batch size, emb dim]
        hidden_enc = []
        # Feed the first time step, then the remaining steps one at a time,
        # keeping a CPU copy of the (h, c) state after each later step
        outputs, hidden = self.rnn(embedded[0, :, :].unsqueeze(0))
        for i in range(1, embedded.shape[0]):
            outputs, hidden = self.rnn(embedded[i, :, :].unsqueeze(0), hidden)
            hidden_enc.append(tuple(h.cpu() for h in hidden))

        # outputs, hidden = self.rnn(embedded)

        # outputs = [src sent len, batch size, hid dim * n directions]
        # hidden = [n layers * n directions, batch size, hid dim]
        # cell = [n layers * n directions, batch size, hid dim]
        # outputs are always from the top hidden layer

        return hidden, hidden_enc
```
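An nn.Embedding lookup like the one in this encoder can be viewed as multiplying a one-hot matrix by the layer's weight matrix. The small check below illustrates that equivalence; the vocabulary size, embedding size and token ids are arbitrary toy values.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size, emb_dim = 6, 4
embedding = nn.Embedding(vocab_size, emb_dim)

# src = [src sent len, batch size], the same layout the encoder expects
src = torch.tensor([[1, 4], [2, 0], [5, 3]])

looked_up = embedding(src)                                 # [3, 2, emb_dim]
one_hot = F.one_hot(src, num_classes=vocab_size).float()   # [3, 2, vocab_size]
via_matmul = one_hot @ embedding.weight                    # [3, 2, emb_dim]

print(torch.allclose(looked_up, via_matmul))  # True
```

So a one-hot "embedding" amounts to fixing that weight matrix to the identity instead of learning it.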

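But what if I want the embeddings to be one-hot encoded? I work on formal languages, and it would be good to preserve orthogonality between tokens. Neither pytorch nor torchtext seems to have any functionality for doing this.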
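One possible approach, sketched under the assumption that the vocabulary size (input_dim) is small enough for dense one-hot vectors: use an identity matrix as a frozen embedding table via nn.Embedding.from_pretrained. The class name OneHotEmbedding is hypothetical. PyTorch also provides torch.nn.functional.one_hot, and F.one_hot(src, num_classes=input_dim).float() yields the same vectors without a module.

```python
import torch
import torch.nn as nn


class OneHotEmbedding(nn.Module):
    """Hypothetical drop-in for nn.Embedding that returns fixed one-hot vectors.

    Row i of the identity matrix is the one-hot vector for token i, so the
    output dimension equals the vocabulary size and nothing is trained.
    """

    def __init__(self, input_dim):
        super().__init__()
        self.lookup = nn.Embedding.from_pretrained(torch.eye(input_dim), freeze=True)

    def forward(self, src):
        return self.lookup(src)


vocab_size = 5
emb = OneHotEmbedding(vocab_size)
src = torch.tensor([[0, 3], [4, 1]])  # [src sent len, batch size]
out = emb(src)                        # [src sent len, batch size, vocab_size]
print(out.shape)  # torch.Size([2, 2, 5])

# Distinct tokens map to orthogonal vectors
print(torch.dot(out[0, 0], out[0, 1]).item())  # 0.0
```

If this replaced self.embedding in the encoder, the LSTM's input size would have to be input_dim rather than emb_dim. Also note that applying self.dropout to one-hot inputs zeroes whole tokens at random (and rescales the surviving ones), which may not be what you want.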

Best answer

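```python
import torch


def get_one_hot_torch_tensor(in_tensor):
    """Convert a 1D or 2D torch tensor of class indices into a one-hot encoding.

    The channel dimension comes first: (N,) -> (n_channels, N) and
    (H, W) -> (n_channels, H, W).
    """
    in_tensor = in_tensor.long()
    n_channels = int(in_tensor.max()) + 1  # maximum number of channels
    if in_tensor.ndim == 1:
        out_one_hot = torch.zeros((n_channels, in_tensor.shape[0]))
        positions = torch.arange(in_tensor.shape[0])
        out_one_hot[in_tensor, positions] = 1
    elif in_tensor.ndim == 2:
        out_one_hot = torch.zeros((n_channels, in_tensor.shape[0], in_tensor.shape[1]))
        # Row and column index of every element, broadcast against in_tensor
        x = torch.arange(in_tensor.shape[0]).unsqueeze(1)
        y = torch.arange(in_tensor.shape[1]).unsqueeze(0)
        out_one_hot[in_tensor, x, y] = 1
    else:
        raise ValueError("expected a 1D or 2D tensor")
    return out_one_hot
```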
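As a cross-check on the function above, the same channel-first layout can be produced with the built-in torch.nn.functional.one_hot, which appends the class dimension last, followed by a permute. This is an alternative suggested here, not part of the original answer; note that F.one_hot expects an integer (int64) tensor and returns int64 values.

```python
import torch
import torch.nn.functional as F

labels = torch.tensor([[0, 2, 1],
                       [1, 0, 2]])  # 2D tensor of class indices, shape (2, 3)

# F.one_hot puts the class dimension last: (2, 3) -> (2, 3, num_classes);
# num_classes is inferred as labels.max() + 1 when not given
one_hot_last = F.one_hot(labels)

# Move the class dimension to the front to get (n_channels, H, W)
one_hot_first = one_hot_last.permute(2, 0, 1).float()
print(one_hot_first.shape)  # torch.Size([3, 2, 3])
print(one_hot_first[2])     # positions where labels == 2
```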

A similar question, python - One-hot encoding in pytorch/torchtext, can be found on Stack Overflow: https://stackoverflow.com/questions/56944018/
