gpt4 book ai didi

keras - 在不使用嵌入的情况下在 keras 中屏蔽 LSTM 中的零输入

转载 作者:行者123 更新时间:2023-12-03 17:03:40 24 4
gpt4 key购买 nike

我正在 Keras 中训练 LSTM:

iclf = Sequential()
iclf.add(Bidirectional(LSTM(units=10, return_sequences=True, recurrent_dropout=0.3), input_shape=(None,2048)))
iclf.add(TimeDistributed(Dense(1, activation='sigmoid')))

每个单元格的输入是一个 2048 向量,它是已知的,不需要学习(如果你愿意,它们是输入句子中单词的 ELMo 嵌入)。因此,这里我没有Embedding层。

由于输入序列具有可变长度,因此使用 pad_sequences 填充它们。 :
X = pad_sequences(sequences=X, padding='post', truncating='post', value=0.0, dtype='float32')

现在,我想告诉 LSTM 忽略这些填充元素。官方的方法是使用带有 mask_zero=True的Embedding层.但是,这里没有嵌入层。如何通知 LSTM 屏蔽零元素?

最佳答案

正如@Today 在评论中所建议的,您可以使用 Masking层。这里我添加了一个玩具问题。

# lstm autoencoder recreate sequence
from numpy import array
from keras.models import Sequential
from keras.layers import LSTM, Masking
from keras.layers import Dense
from keras.layers import RepeatVector
from keras.layers import TimeDistributed
from keras.utils import plot_model
from keras.preprocessing.sequence import pad_sequences


# define input sequence
sequence = array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8],
[0.3, 0.4, 0.5, 0.6]])
# make sure to use dtype='float32' in padding otherwise with floating points
sequence = pad_sequences(sequence, padding='post', dtype='float32')


# reshape input into [samples, timesteps, features]
n_obs = len(sequence)
n_in = 9
sequence = sequence.reshape((n_obs, n_in, 1))

# define model
model = Sequential()
model.add(Masking(mask_value=0, input_shape=(n_in, 1)))
model.add(LSTM(100, activation='relu', input_shape=(n_in,1) ))
model.add(RepeatVector(n_in))
model.add(LSTM(100, activation='relu', return_sequences=True))
model.add(TimeDistributed(Dense(1)))
model.compile(optimizer='adam', loss='mse')
# fit model
model.fit(sequence, sequence, epochs=300, verbose=0)
plot_model(model, show_shapes=True, to_file='reconstruct_lstm_autoencoder.png')
# demonstrate recreation
yhat = model.predict(sequence, verbose=0)
print(yhat[0,:,0])

关于keras - 在不使用嵌入的情况下在 keras 中屏蔽 LSTM 中的零输入,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/53172852/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com