lstm - Need help understanding a ConvLSTM implementation in PyTorch


I am having trouble understanding the following implementation of ConvLSTM. I don't quite understand what input_size + hidden_size is, or the 4 * hidden_size value for the output. model = ConvLSTMCell(c, d) tells us that c and d are input_size and hidden_size, which are 3 and 5 respectively. I assume c is the number of channels and d stands for the output dimension? Also, why do we multiply hidden_size by 4 for the output? Is 4 the four gates an LSTM cell has by default? Could someone more experienced explain to me what is happening in the convolution? Thank you.

self.Gates = nn.Conv2d(input_size + hidden_size, 4 * hidden_size,
                       KERNEL_SIZE, padding=PADDING)

import torch
from torch import nn
import torch.nn.functional as f
from torch.autograd import Variable


# Define some constants
KERNEL_SIZE = 3
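# pad by kernel_size // 2 so the convolution preserves height and width
# ("same" padding for odd kernel sizes)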
PADDING = KERNEL_SIZE // 2


class ConvLSTMCell(nn.Module):
"""
Generate a convolutional LSTM cell
"""

def __init__(self, input_size, hidden_size):
super().__init__()
self.input_size = input_size
self.hidden_size = hidden_size
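        # one convolution computes the pre-activations of all four gates at
        # once: it reads the input and previous hidden state concatenated
        # along the channel axis (input_size + hidden_size channels in) and
        # emits 4 * hidden_size channels out, one hidden_size-channel block
        # per gate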
self.Gates = nn.Conv2d(input_size + hidden_size, 4 * hidden_size, KERNEL_SIZE, padding=PADDING)

def forward(self, input_, prev_state):

# get batch and spatial sizes
batch_size = input_.data.size()[0]
spatial_size = input_.data.size()[2:]

# generate empty prev_state, if None is provided
if prev_state is None:
state_size = [batch_size, self.hidden_size] + list(spatial_size)
prev_state = (
Variable(torch.zeros(state_size)),
Variable(torch.zeros(state_size))
)

prev_hidden, prev_cell = prev_state

# data size is [batch, channel, height, width]
stacked_inputs = torch.cat((input_, prev_hidden), 1)
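        # a single conv pass produces the stacked gate pre-activations,
        # shaped [batch, 4 * hidden_size, height, width]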
gates = self.Gates(stacked_inputs)

# chunk across channel dimension
in_gate, remember_gate, out_gate, cell_gate = gates.chunk(4, 1)

# apply sigmoid non linearity
in_gate = f.sigmoid(in_gate)
remember_gate = f.sigmoid(remember_gate)
out_gate = f.sigmoid(out_gate)

# apply tanh non linearity
cell_gate = f.tanh(cell_gate)

# compute current cell and hidden state
cell = (remember_gate * prev_cell) + (in_gate * cell_gate)
hidden = out_gate * f.tanh(cell)

return hidden, cell


def _main():
"""
Run some basic tests on the API
"""

# define batch_size, channels, height, width
b, c, h, w = 1, 3, 4, 8
d = 5 # hidden state size
lr = 1e-1 # learning rate
T = 6 # sequence length
max_epoch = 20 # number of epochs

# set manual seed
torch.manual_seed(0)

print('Instantiate model')
model = ConvLSTMCell(c, d)
print(repr(model))

print('Create input and target Variables')
x = Variable(torch.rand(T, b, c, h, w))
y = Variable(torch.randn(T, b, d, h, w))

print('Create a MSE criterion')
loss_fn = nn.MSELoss()

print('Run for', max_epoch, 'iterations')
for epoch in range(0, max_epoch):
state = None
loss = 0
for t in range(0, T):
state = model(x[t], state)
loss += loss_fn(state[0], y[t])

print(' > Epoch {:2d} loss: {:.3f}'.format((epoch+1), loss.data[0]))

# zero grad parameters
model.zero_grad()

# compute new grad parameters through time!
loss.backward()

# learning_rate step against the gradient
for p in model.parameters():
p.data.sub_(p.grad.data * lr)

print('Input size:', list(x.data.size()))
print('Target size:', list(y.data.size()))
print('Last hidden state size:', list(state[0].size()))


if __name__ == '__main__':
_main()


__author__ = "Alfredo Canziani"
__credits__ = ["Alfredo Canziani"]
__maintainer__ = "Alfredo Canziani"
__email__ = "alfredo.canziani@gmail.com"
__status__ = "Prototype" # "Prototype", "Development", or "Production"
__date__ = "Jan 17"

Best Answer

Yes, you are right: the number of output channels is multiplied by 4 because there are four gates.
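
To spell this out: the cell concatenates the input x (input_size channels) with the previous hidden state (hidden_size channels), so the convolution sees input_size + hidden_size input channels. Its 4 * hidden_size output channels are then split by chunk(4, 1) into the four standard LSTM gates, each with hidden_size channels: the input gate, the forget (remember) gate, the output gate, and the candidate cell state. The update itself is the usual LSTM update with the matrix multiplications replaced by convolutions: cell = remember_gate * prev_cell + in_gate * cell_gate, then hidden = out_gate * tanh(cell). Here is a minimal shape check, a sketch assuming the ConvLSTMCell class above with c = 3 and d = 5 as in _main():

import torch
from torch.autograd import Variable

cell = ConvLSTMCell(3, 5)              # input_size=3, hidden_size=5
x = Variable(torch.rand(1, 3, 4, 8))   # [batch, channel, height, width]
h = Variable(torch.zeros(1, 5, 4, 8))  # previous hidden state

stacked = torch.cat((x, h), 1)
print(stacked.size())                  # torch.Size([1, 8, 4, 8])  -> 3 + 5 channels in

gates = cell.Gates(stacked)
print(gates.size())                    # torch.Size([1, 20, 4, 8]) -> 4 * 5 channels out

i, f, o, g = gates.chunk(4, 1)
print(i.size())                        # torch.Size([1, 5, 4, 8])  -> one gate, 5 channels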

Regarding "lstm - Need help understanding a ConvLSTM implementation in PyTorch", we found a similar question on Stack Overflow: https://stackoverflow.com/questions/50914300/
