gpt4 book ai didi

python - 我如何知道 keras RNN 输入数据的正确格式?

转载 作者:行者123 更新时间:2023-11-30 09:05:03 24 4
gpt4 key购买 nike

我正在尝试构建一个 Elman 简单 RNN,如 here 中所述。 。

我使用 Keras 构建了模型,如下所示:

model = keras.Sequential()
model.add(keras.layers.SimpleRNN(7,activation =None,use_bias=True,input_shape=
[x_train.shape[0],x_train.shape[1]]))
model.add(keras.layers.Dense(7,activation = tf.nn.sigmoid))

model.summary()
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
simple_rnn_2 (SimpleRNN) (None, 7) 105
_________________________________________________________________
dense_2 (Dense) (None, 7) 56
=================================================================
Total params: 161
Trainable params: 161
Non-trainable params: 0
_________________________________________________________________

我的训练数据目前的形状为 (15000, 7, 7)。也就是说,15000 个长度为 7 的一热编码实例,对七个字母之一进行编码。例如[0,0,0,1,0,0,0],[0,0,0,0,1,0,0]等等。

数据的标签格式相同,因为每个字母都预测序列中的下一个字母,即 [0,1,0,0,0,0,0]有标签[0,0,1,0,0,0,0] .

所以,训练数据(x_train)和培训标签(y_train)形状均为(15000,7,7) .

我的验证数据 x_val 和 y_val 的形状为 (10000,7,7) 。即相同的形状,只是实例较少。

所以当我运行我的模型时:

history = model.fit(x_train,
y_train,
epochs = 40,
batch_size=512,
validation_data = (x_val,y_val))

我收到错误:

ValueError: Error when checking input: expected simple_rnn_7_input to have shape (15000, 7) but got array with shape (7, 7)

显然,我的输入数据格式不正确,无法输入 Keras RNN,但我不知道如何为其提供正确的输入。

有人可以告诉我解决方案吗?

最佳答案

  1. SimpleRNN 层需要维度 (seq_length, input_dim) 的输入在你的例子中,即 (7,7)。
  2. 此外,如果您想要每个时间步的输出,则需要使用 return_sequence=True ,默认为false 。这样您就可以比较时间步长的输出。

所以模型架构将是这样的:

model.add(keras.layers.SimpleRNN(7, activation='tanh', 
return_sequences=True,
input_shape=[7,7]))
model.add(keras.layers.Dense(7))
model.summary()

_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
simple_rnn_12 (SimpleRNN) (None, 7, 7) 105
_________________________________________________________________
dense_2 (Dense) (None, 7, 7) 56
=================================================================
Total params: 161
Trainable params: 161
Non-trainable params: 0
_________________________________________________________________

现在在训练时,它需要数据 input and output暗淡 (num_samples, seq_length, input_dims)(15000, 7, 7)对于两者。

model.compile(loss='categorical_crossentropy', optimizer='adam')# define any loss, you want
model.fit(x_train, y_train, epochs=2)

关于python - 我如何知道 keras RNN 输入数据的正确格式?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/54118715/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com