gpt4 book ai didi

python - 手动双向 torch.nn.RNN 实现

转载 作者:行者123 更新时间:2023-12-05 05:44:53 24 4
gpt4 key购买 nike

我正在尝试在没有 C++/CUDA 绑定(bind)的情况下重新实现 torch.nn.RNN 模块,即使用简单的张量运算和相关逻辑。我开发了以下 RNN 类和相关的测试逻辑,可用于将输出与引用模块实例进行比较:

import torch
import torch.nn as nn


class RNN(nn.Module):
def __init__(self, input_size, hidden_size, num_layers, bidirectional=False):
super(RNN, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.num_layers = num_layers
self.bidirectional = bidirectional
self.w_ih = [torch.randn(hidden_size, input_size)]
if bidirectional:
self.w_ih_reverse = [torch.randn(hidden_size, input_size)]

for layer in range(num_layers - 1):
self.w_ih.append(torch.randn(hidden_size, hidden_size))
if bidirectional:
self.w_ih_reverse.append(torch.randn(hidden_size, hidden_size))

self.w_hh = torch.randn(num_layers, hidden_size, hidden_size)
if bidirectional:
self.w_hh_reverse = torch.randn(num_layers, hidden_size, hidden_size)

def forward(self, input, h_0=None):
if h_0 is None:
if self.bidirectional:
h_0 = torch.zeros(2, self.num_layers, input.shape[1], self.hidden_size)
else:
h_0 = torch.zeros(1, self.num_layers, input.shape[1], self.hidden_size)

if self.bidirectional:
output = torch.zeros(input.shape[0], input.shape[1], 2 * self.hidden_size)
else:
output = torch.zeros(input.shape[0], input.shape[1], self.hidden_size)

for t in range(input.shape[0]):
print(input.shape, t)
input_t = input[t]
if self.bidirectional:
input_t_reversed = input[-1 - t]

for layer in range(self.num_layers):
h_t = torch.tanh(torch.matmul(input_t, self.w_ih[layer].T) + torch.matmul(h_0[0][layer], self.w_hh[layer].T))
h_0[0][layer] = h_t
if self.bidirectional:
h_t_reverse = torch.tanh(torch.matmul(input_t_reversed, self.w_ih_reverse[layer].T) + torch.matmul(h_0[1][layer], self.w_hh_reverse[layer].T))
h_0[1][layer] = h_t_reverse

input_t = h_t
if self.bidirectional:
# This logic is incorrect for bidirectional RNNs with multiple layers
input_t = torch.cat((h_t, h_t_reverse), dim=-1)
input_t_reversed = input_t

output[t, :, :self.hidden_size] = h_t
if self.bidirectional:
output[-1 - t, :, self.hidden_size:] = h_t_reverse

return output


if __name__ == '__main__':
input_size = 10
hidden_size = 12
num_layers = 2
batch_size = 2
bidirectional = True
input = torch.randn(2, batch_size, input_size)
rnn_val = torch.nn.RNN(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, bias=False, bidirectional=bidirectional, nonlinearity='tanh')
rnn = RNN(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, bidirectional=bidirectional)
for i in range(rnn_val.num_layers):
rnn.w_ih[i] = rnn_val._parameters['weight_ih_l%d' % i].data
rnn.w_hh[i] = rnn_val._parameters['weight_hh_l%d' % i].data
if bidirectional:
rnn.w_ih_reverse[i] = rnn_val._parameters['weight_ih_l%d_reverse' % i].data
rnn.w_hh_reverse[i] = rnn_val._parameters['weight_hh_l%d_reverse' % i].data

output_val, hn_val = rnn_val(input)
output = rnn(input)
print(output_val)
print(output)

除了单层双向 RNN 之外,我的实现似乎适用于具有任意层数和不同批量大小/序列长度的普通 RNN,但是,它不会为多层双向 RNN 产生正确的结果。

为了简单起见,目前没有实现偏置项,只支持tanh激活函数。我已将逻辑错误缩小到行 input_t = torch.cat((h_t, h_t_reverse), dim=-1),因为第一个输出序列不正确。

如果有人能指出正确的方向并让我知道问题出在哪里,我将不胜感激!

最佳答案

有两种可能的转发方式:

  • 通过第一个元素逐步遍历所有层,然后逐步遍历时间(这个在上面的代码片段中采用)
  • 按顺序遍历层,计算每个索引的输出(或等效的下一层输入)

因此,虽然第一个适用于单向 RNN,但它不适用于双向多层 RNN。为了说明,让我们采用 2 层(代码中的相同情况)——用于计算 output[0] 需要来自前一层的输入,它是以下内容的串联:

  1. 来自长度 1 的正常传递的隐藏向量(因为它就在序列的开始)
  2. 和长度为 seq_length 的反向传递的隐藏向量(需要遍历整个序列,从头到尾,才能得到它)

因此,当首先通过层进行单步执行时,它只需要一步时间(遍历长度等于 1),因此 output[0] 将垃圾作为输入,因为第二部分不是正确(没有“从头到尾的整个过程”)。

为了克服这个问题,我建议重写正向循环:

for t in range(input.shape[0]):
...
for layer in range(self.num_layers):
...

类似于:

for layer in range(self.num_layers):
...
for t in range(input.shape[0]):
...

作为替代方案,在其他正常情况下保留forward用于计算,但对于双向多层编写另一个函数forward_bidir,并将此循环写在那里。

还值得注意的是,w_ih[k] 在双向情况下 k > 0 的形状为 (hidden_​​size, 2 * hidden_​​size),如 pytorch documentation on RNN 中所述.此外,函数 torch.allclose 应该比打印更好地用于测试输出。

对于代码修复检查 gist ,未进行任何优化,主要目的是保留原始想法,似乎适用于上面列出的所有配置(单向、多层、双向)。

关于python - 手动双向 torch.nn.RNN 实现,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/71522409/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com