gpt4 book ai didi

python - Keras reshape 输入 LSTM

转载 作者:太空宇宙 更新时间:2023-11-03 15:33:47 24 4
gpt4 key购买 nike

我正在研究一个虚拟示例,以了解 LSTM 如何使用 Keras 工作。我在 reshape 数据输入和输出的方式上遇到问题。

ValueError:输入 0 与循环层不兼容:预期 ndim=3,发现 ndim=2

import random
import numpy as np

from keras.layers import Input, LSTM, Dense
from keras.layers.wrappers import TimeDistributed
from keras.models import Model

def gen_number():
return np.random.choice([random.random(), 1], p=[0.2, 0.8])
truth_input = [gen_number() for i in range(0,2000)]
# shift input by one
truth_shifted = truth_input[1:] + [np.mean(truth_input)]
truth = np.array(truth_input)
test_ouput = np.array(truth_shifted)
truth_reshaped = truth.reshape(1, len(truth), 1)
shifted_truth_reshaped = test_ouput.reshape(1, len(test_ouput), 1)
yes = Input(shape=(len(truth_reshaped),), name = 'truth_in')
recurrent = LSTM(20, return_sequences=True, name='recurrent')(yes)
TimeDistributed_output = TimeDistributed(Dense(1), name='test_pseudo')(recurrent)
model_built = Model(input=yes, output=TimeDistributed_output)
model_built.compile(loss='mse', optimizer='adam')
model_built.fit(truth_reshaped, shifted_truth_reshaped, nb_epoch=100)

我需要怎样做才能正确输入数据?

最佳答案

yes = Input(shape=(len(truth_reshaped),), name = 'truth_in')

Len(truth_reshape) 将返回 1,因为您将其整形为 (1,2000,1)。这里第一个 1 是序列数,2000 是序列中的时间步数,第二个 1 是序列中每个元素的值数。

所以你的输入应该是

yes = Input(shape=(len(truth),1), name = 'truth_in')

这将告诉您的网络,输入将是长度为 len(truth,1) 且元素维度为 1 的序列。

关于python - Keras reshape 输入 LSTM,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/42702491/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com