gpt4 book ai didi

python - 在 Keras Tensorflow 中将 TimeseriesGenerator 与多元数据集结合使用

转载 作者:行者123 更新时间:2023-12-05 05:07:15 27 4
gpt4 key购买 nike

我正在尝试对 Keras 中的 TimeseriesGenerator 的输出建模,该输出将用作 LSTM 网络的输入,但一直面临问题。数据集具有以下结构:

enter image description here

其中特征集以绿色(F1 到 F6)显示,目标变量 (T) 以红色显示。我已将包含 3170 个观察值的总数据集分成三组:

enter image description here

由于 Keras 中的 LSTM 需要三个维度的输入大小,我使用以下命令 reshape 了数据集:

        train= train.reshape((train_df.shape[0], 1, train_df.shape[1]))
validation= validation.reshape((validation.shape[0], 1, validation.shape[1]))
test= test.reshape((test.shape[0], 1, test.shape[1]))

因此,重构数据集的大小如下:

enter image description here

三个维度在哪里(样本、时间步长、特征)。但真正的问题是现在数据集何时传递给 keras 中的时间序列生成器。使用的生成器代码如下:

        generator = TimeseriesGenerator(train, train_target, length=1, batch_size=10)

TimeseriesGenerator 将数据集传递给 fit_generator,如下所示:

        model.fit_generator(generator, validation_data=(validation, validation_target),
epochs=100, verbose=0,
shuffle=False, workers=1, use_multiprocessing=True)

我在 Keras 中的 LSTM 网络配置如下:

                model = Sequential()
model.add(LSTM(200, input_shape=(10, 6), return_sequences=True))
model.add(LSTM(200, input_shape=(10, 6), return_sequences=False))
model.add(Dense(1, kernel_initializer='uniform', activation='linear'))

第一个 LSTM 层的 input_shape 是 (10,6),这意味着 10 个样本/观测值具有 6 个特征。我选择 (10,6) 的 input_shape,因为 TimeseriesGenerator 应该生成 10 个样本的 batch_size,每个样本具有 6 个特征。

但是,这会导致如下错误:

ValueError: Error when checking input: expected lstm_input to have 3 dimensions, but got array with shape (10, 1, 1, 6)

TimeseriesGenerator 生成训练集的输入大小为(10, 1, 1, 6)。生成的火车数据集有四个维度,但是,我希望 TimeseriesGenerator 生成 10 个样本的 batch_size,每个样本具有 6 个特征,即输入大小为 (10,1,6).

如何让 TimeseriesGenerator 生成大小为 (10,1,6) 的输入?

最佳答案

错误是...不要 reshape 您的训练、验证和测试数据...生成器会自动将其变为 (10, 1,6) 的形状...因为当您给定批量大小 10 时,它将需要 10批处理......具有 1 个数据和 6 个特征的长度......试试吧。它会工作..

只要给你的火车,验证,测试数据在(m,n)生成器会自己reshape

关于python - 在 Keras Tensorflow 中将 TimeseriesGenerator 与多元数据集结合使用,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/59301113/

27 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com