gpt4 book ai didi

python - 密集层的 LSTM 初始状态

转载 作者:太空狗 更新时间:2023-10-30 00:38:11 40 4
gpt4 key购买 nike

我正在对时间序列数据使用 lstm。我有关于时间序列的不依赖于时间的特征。想象一下该系列的公司股票以及非时间系列功能中的公司位置之类的东西。这不是用例,但它是相同的想法。对于此示例,我们只预测时间序列中的下一个值。

一个简单的例子是:

feature_input = Input(shape=(None, data.training_features.shape[1]))
dense_1 = Dense(4, activation='relu')(feature_input)
dense_2 = Dense(8, activation='relu')(dense_1)

series_input = Input(shape=(None, data.training_series.shape[1]))
lstm = LSTM(8)(series_input, initial_state=dense_2)
out = Dense(1, activation="sigmoid")(lstm)

model = Model(inputs=[feature_input,series_input], outputs=out)
model.compile(loss='mean_squared_error', optimizer='adam', metrics=["mape"])

但是,我只是不确定如何正确指定列表中的初始状态。我明白了

ValueError: An initial_state was passed that is not compatible with `cell.state_size`. Received `state_spec`=[<keras.engine.topology.InputSpec object at 0x11691d518>]; However `cell.state_size` is (8, 8)

我可以看出这是由 3d 批处理维度引起的。我尝试使用 Flatten、Permutation 和 Resize 图层,但我认为这是不正确的。我缺少什么以及如何连接这些层?

最佳答案

第一个问题是 LSTM(8) 层需要两个初始状态 h_0c_0,每个维度 (无,8)。这就是错误消息中“cell.state_size is (8, 8)”的含义。

如果您只有一个初始状态 dense_2,也许您可​​以切换到 GRU(它只需要 h_0)。或者,您可以将 feature_input 转换为两个初始状态。

第二个问题是h_0c_0的形状是(batch_size, 8),但是你的dense_2 的形状为 (batch_size, timesteps, 8)。在使用 dense_2 作为初始状态之前,您需要处理时间维度。

因此,也许您可​​以将输入形状更改为 (data.training_features.shape[1],) 或使用 GlobalAveragePooling1D 对时间步取平均值。

一个可行的例子是:

feature_input = Input(shape=(5,))
dense_1_h = Dense(4, activation='relu')(feature_input)
dense_2_h = Dense(8, activation='relu')(dense_1_h)
dense_1_c = Dense(4, activation='relu')(feature_input)
dense_2_c = Dense(8, activation='relu')(dense_1_c)

series_input = Input(shape=(None, 5))
lstm = LSTM(8)(series_input, initial_state=[dense_2_h, dense_2_c])
out = Dense(1, activation="sigmoid")(lstm)
model = Model(inputs=[feature_input,series_input], outputs=out)
model.compile(loss='mean_squared_error', optimizer='adam', metrics=["mape"])

关于python - 密集层的 LSTM 初始状态,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/48233400/

40 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com