
python-3.x - How to add an attention layer (along with a Bi-LSTM layer) in a Keras Sequential model?


I'm trying to find an easy way to add an attention layer to a Keras Sequential model, but I've run into a lot of problems getting it to work.

I'm new to deep learning, so I chose Keras as my starting point. My task is to build a Bi-LSTM with attention. I built a Bi-LSTM model on the IMDB dataset, and I found a package called 'keras-self-attention' ( https://pypi.org/project/keras-self-attention/ ), but I ran into problems when adding its attention layer to the Keras Sequential model.

from keras.datasets import imdb
from keras.preprocessing import sequence
from keras_self_attention import SeqSelfAttention

max_features = 10000
maxlen = 500
batch_size = 32

# data
(x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=max_features)
x_train = sequence.pad_sequences(x_train, maxlen=maxlen)
x_test = sequence.pad_sequences(x_test, maxlen=maxlen)

# model
from keras import models
from keras.layers import Dense, Embedding, LSTM, Bidirectional

model = models.Sequential()
model.add(Embedding(max_features, 32))
model.add(Bidirectional(LSTM(32, return_sequences=True)))
# add an attention layer
model.add(SeqSelfAttention(attention_activation='sigmoid'))
model.add(Dense(1, activation='sigmoid'))

# compile and fit
model.compile(optimizer='rmsprop', loss='binary_crossentropy', metrics=['acc'])
history = model.fit(x_train, y_train, epochs=10, batch_size=128, validation_split=0.2)

The code above raises the following ValueError:

ValueError                                Traceback (most recent call last)
<ipython-input-97-e6eb02d043c4> in <module>()
----> 1 history = model3.fit(x_train, y_train, epochs=10, batch_size=128, validation_split=0.2)

~/denglz/venv4re/lib/python3.6/site-packages/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, **kwargs)
    950                 sample_weight=sample_weight,
    951                 class_weight=class_weight,
--> 952                 batch_size=batch_size)
    953         # Prepare validation data.
    954         do_validation = False

~/denglz/venv4re/lib/python3.6/site-packages/keras/engine/training.py in _standardize_user_data(self, x, y, sample_weight, class_weight, check_array_lengths, batch_size)
    787                 feed_output_shapes,
    788                 check_batch_axis=False,  # Don't enforce the batch size.
--> 789                 exception_prefix='target')
    790
    791         # Generate sample-wise weight values given the `sample_weight` and

~/denglz/venv4re/lib/python3.6/site-packages/keras/engine/training_utils.py in standardize_input_data(data, names, shapes, check_batch_axis, exception_prefix)
    126                         ': expected ' + names[i] + ' to have ' +
    127                         str(len(shape)) + ' dimensions, but got array '
--> 128                         'with shape ' + str(data_shape))
    129             if not check_batch_axis:
    130                 data_shape = data_shape[1:]

ValueError: Error when checking target: expected dense_7 to have 3 dimensions, but got array with shape (25000, 1)

So what's going on? I'm new to deep learning, so if you know the answer, please help me.

Best Answer

In your code, the attention layer (SeqSelfAttention) produces an output with the same shape as its input, so here it is still 3-dimensional: one 64-dimensional vector per time step. The final Dense(1) layer is then applied to every time step, so the model's output is 3-D, while your targets are 2-D with shape (25000, 1). That mismatch is exactly what the ValueError complains about.
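Here is a minimal sketch (mine, not part of the original answer) that makes the shape flow visible; the layer sizes match the question's model:

# The output stays 3-D all the way through, so it can never match 2-D labels.
from keras import models
from keras.layers import Dense, Embedding, LSTM, Bidirectional
from keras_self_attention import SeqSelfAttention

m = models.Sequential()
m.add(Embedding(10000, 32))                              # -> (None, None, 32)
m.add(Bidirectional(LSTM(32, return_sequences=True)))    # -> (None, None, 64)
m.add(SeqSelfAttention(attention_activation='sigmoid'))  # -> (None, None, 64), shape unchanged
m.add(Dense(1, activation='sigmoid'))                    # -> (None, None, 1), still 3-D
print(m.output_shape)  # (None, None, 1)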

Use SeqWeightedAttention instead; it collapses the time dimension into a single weighted vector, so the model's output becomes 2-D and matches the targets:

from keras.datasets import imdb
from keras.preprocessing import sequence
from keras_self_attention import SeqSelfAttention, SeqWeightedAttention

max_features = 10000
maxlen = 500
batch_size = 32

# data
(x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=max_features)
x_train = sequence.pad_sequences(x_train, maxlen=maxlen)
x_test = sequence.pad_sequences(x_test, maxlen=maxlen)

# model
from keras import models
from keras.layers import Dense, Embedding, LSTM, Bidirectional

model = models.Sequential()
# model.add(Embedding(max_features, 32, mask_zero=True))
model.add(Embedding(max_features, 32))
model.add(Bidirectional(LSTM(32, return_sequences=True)))

# add an attention layer
# model.add(SeqSelfAttention(attention_activation='sigmoid'))
model.add(SeqWeightedAttention())  # collapses (batch, time, 64) -> (batch, 64)

model.add(Dense(1, activation='sigmoid'))

# compile and fit
model.compile(optimizer='rmsprop', loss='binary_crossentropy', metrics=['acc'])
model.summary()

history = model.fit(x_train, y_train, epochs=1, batch_size=128, validation_split=0.2)

Here's the code; running it prints the model summary and the training output.
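If you would rather keep SeqSelfAttention, as the question originally intended, one option is to collapse the time dimension yourself with a pooling layer right after it. This is a hedged sketch of my own, not part of the accepted answer; it reuses max_features and the imports from the code above:

from keras.layers import GlobalAveragePooling1D

model = models.Sequential()
model.add(Embedding(max_features, 32))
model.add(Bidirectional(LSTM(32, return_sequences=True)))
model.add(SeqSelfAttention(attention_activation='sigmoid'))  # (batch, time, 64)
model.add(GlobalAveragePooling1D())                          # (batch, 64): time axis averaged out
model.add(Dense(1, activation='sigmoid'))                    # (batch, 1), matches the targets
model.compile(optimizer='rmsprop', loss='binary_crossentropy', metrics=['acc'])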

Regarding python-3.x - How to add an attention layer (along with a Bi-LSTM layer) in a Keras Sequential model?, we found a similar question on Stack Overflow: https://stackoverflow.com/questions/57438806/
