gpt4 book ai didi

keras - 如何在 Keras 中合并两个 LSTM 层

转载 作者:行者123 更新时间:2023-12-04 19:34:52 26 4
gpt4 key购买 nike

我正在与 Keras 合作执行句子相似性任务(使用 STS 数据集)并且在合并层时遇到问题。数据由 1184 个句子对组成,每个句子对的得分在 0 到 5 之间。以下是我的 numpy 数组的形状。我将每个句子填充为 50 个单词,并使用 100 维的手套嵌入将它们穿过嵌入层。合并两个网络时,我收到一个错误..

Exception: Error when checking model input: the list of Numpy arrays that you are passing to your model is not the size the model expected. Expected to see 1 arrays but instead got the following list of 2 arrays:

这是我的代码的样子
total training data = 1184
X1.shape = (1184, 50)
X2.shape = (1184, 50)
Y.shape = (1184, 1)


embedding_matrix = np.zeros((len(word_index) + 1, EMBEDDING_DIM))
for word, i in word_index.items():
embedding_vector = embeddings_index.get(word)
if embedding_vector is not None:
# words not found in embedding index will be all-zeros.
embedding_matrix[i] = embedding_vector

embedding_layer = Embedding(len(word_index) + 1,
EMBEDDING_DIM,
weights=[embedding_matrix],
input_length=50,
trainable=False)

s1rnn = Sequential()
s1rnn.add(embedding_layer)
s1rnn.add(LSTM(128, input_shape=(100, 1)))
s1rnn.add(Dense(1))

s2rnn = Sequential()
s2rnn.add(embedding_layer)
s2rnn.add(LSTM(128, input_shape=(100, 1)))
s2rnn.add(Dense(1))

model = Sequential()
model.add(Merge([s1rnn,s2rnn],mode='concat'))
model.add(Dense(1))
model.compile(loss='mean_squared_error', optimizer='RMSprop', metrics=['accuracy'])
model.fit([X1,X2], Y,batch_size=32, nb_epoch=100, validation_split=0.05)

最佳答案

问题不在于合并层。您需要创建两个嵌入层以输入 2 个不同的输入。

以下修改应该有效:

embedding_layer_1 = Embedding(len(word_index) + 1,
EMBEDDING_DIM,
weights=[embedding_matrix],
input_length=50,
trainable=False)

embedding_layer_2 = Embedding(len(word_index) + 1,
EMBEDDING_DIM,
weights=[embedding_matrix],
input_length=50,
trainable=False)


s1rnn = Sequential()
s1rnn.add(embedding_layer_1)
s1rnn.add(LSTM(128, input_shape=(100, 1)))
s1rnn.add(Dense(1))

s2rnn = Sequential()
s2rnn.add(embedding_layer_2)
s2rnn.add(LSTM(128, input_shape=(100, 1)))
s2rnn.add(Dense(1))

关于keras - 如何在 Keras 中合并两个 LSTM 层,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/41052494/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com