gpt4 book ai didi

machine-learning - 烤宽面条中 LSTM 可能出现的问题

转载 作者:行者123 更新时间:2023-11-30 09:21:32 25 4
gpt4 key购买 nike

使用教程中给出的简单 LSTM 构造函数和维度 [,,1] 的输入,人们期望看到形状 [,,num_units]。但无论构造过程中传递的 num_units 是多少,输出的形状都与输入相同。

以下是复制此问题的最小代码...

    import lasagne
import theano
import theano.tensor as T
import numpy as np

num_batches= 20
sequence_length= 100
data_dim= 1
train_data_3= np.random.rand(num_batches,sequence_length,data_dim).astype(theano.config.floatX)

#As in the tutorial
forget_gate = lasagne.layers.Gate(b=lasagne.init.Constant(5.0))
l_lstm = lasagne.layers.LSTMLayer(
(num_batches,sequence_length, data_dim),
num_units=8,
forgetgate=forget_gate
)

lstm_in= T.tensor3(name='x', dtype=theano.config.floatX)

lstm_out = lasagne.layers.get_output(l_lstm, {l_lstm:lstm_in})
f = theano.function([lstm_in], lstm_out)
lstm_output_np= f(train_data_3)

lstm_output_np.shape
#= (20, 100, 1)

一个不合格的 LSTM(我的意思是在默认模式下)应该为每个单元产生一个输出,对吗?该代码在 kaixhin 的 cuda lasagne docker 镜像 docker image 上运行是什么赋予了?谢谢!

最佳答案

您可以使用 lasagne.layers.InputLayer 来解决这个问题

import lasagne
import theano
import theano.tensor as T
import numpy as np

num_batches= 20
sequence_length= 100
data_dim= 1
train_data_3= np.random.rand(num_batches,sequence_length,data_dim).astype(theano.config.floatX)

#As in the tutorial
forget_gate = lasagne.layers.Gate(b=lasagne.init.Constant(5.0))
input_layer = lasagne.layers.InputLayer(shape=(num_batches, # <-- change
sequence_length, data_dim),) # <-- change
l_lstm = lasagne.layers.LSTMLayer(input_layer, # <-- change
num_units=8,
forgetgate=forget_gate
)

lstm_in= T.tensor3(name='x', dtype=theano.config.floatX)

lstm_out = lasagne.layers.get_output(l_lstm, lstm_in) # <-- change
f = theano.function([lstm_in], lstm_out)
lstm_output_np= f(train_data_3)

print lstm_output_np.shape

如果将输入输入到 input_layer 中,它就不再含糊不清,因此您甚至不需要指定输入应该去的位置。直接指定形状并将tensor3添加到LSTM中是行不通的。

关于machine-learning - 烤宽面条中 LSTM 可能出现的问题,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/35625409/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com