gpt4 book ai didi

python - ValueError : Error when checking input: expected lstm_1_input to have 3 dimensions, 但得到形状为 (10, 1) 的数组

转载 作者:太空宇宙 更新时间:2023-11-04 09:48:12 29 4
gpt4 key购买 nike

我正在努力解决 LSTM input_shape 问题。在这里,我制作了一个简单的 LSTM 网络,应该对其进行训练,以将输入加倍。

from keras.models import Sequential
from keras.layers import LSTM, Dense
import numpy as np

X = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
y = np.array([2, 4, 6, 8, 10, 12, 14, 16, 18, 20])

data_dim = 1
timesteps = 8

model = Sequential()
model.add(LSTM(32, return_sequences=True, input_shape=(timesteps, data_dim)))
model.add(LSTM(32, return_sequences=True))
model.add(Dense(10, activation='softmax'))

model.compile(loss='categorical_crossentropy', optimizer='rmsprop', metrics=['accuracy'])

model.fit(X,y, batch_size=10, epochs=1000)

但是总是出现这样的错误信息:ValueError:检查输入时出错:预期 lstm_1_input 具有 3 个维度,但得到形状为 (10, 1) 的数组我究竟做错了什么?有人可以向我解释 input_shape 的事情吗?亲切的问候。尼克拉斯

最佳答案

您的代码有很多问题。

1) 你想要一个回归问题。在最后一层,softmax 会将数字压缩到 0 和 1 之间。您需要线性激活。

2) 因此,损失函数应该是mean_square_error

3) 目标 y 的形状决定了每个时间步的输出层大小应该是 1 而不是 10。

4) LSTM 层的输入和输出数组的形状应该是 (batch_size, time_step, dim)。

5) LSTM 层中定义的时间步长与输入数据的时间步长应该相同。

我将这些更改合并到您的代码中:

from keras.models import Sequential
from keras.layers import LSTM, Dense
import numpy as np

X = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
y = np.array([2, 4, 6, 8, 10, 12, 14, 16, 18, 20])

X = X.reshape(1,10,1)
y = y.reshape(1,10,1)

data_dim = 1
timesteps = 10

model = Sequential()
model.add(LSTM(32, return_sequences=True, input_shape=(timesteps, data_dim)))
model.add(LSTM(32, return_sequences=True))
model.add(Dense(1, activation='linear'))

print(model.summary())

model.compile(loss='mean_squared_error', optimizer='rmsprop', metrics=['accuracy'])

model.fit(X,y, batch_size=1, epochs=1000)

关于python - ValueError : Error when checking input: expected lstm_1_input to have 3 dimensions, 但得到形状为 (10, 1) 的数组,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/48978609/

29 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com