gpt4 book ai didi

python - Keras Stateful LSTM fit_generator 如何使用 batch_size > 1

转载 作者:行者123 更新时间:2023-11-28 19:04:42 25 4
gpt4 key购买 nike

我想使用 Keras 中的函数式 API 训练一个有状态 LSTM 网络。

拟合方法是fit_generator

我可以训练它,使用:batch_size = 1

我的输入层是:

Input(shape=(n_history, n_cols),batch_shape=(batch_size, n_history, n_cols), 
dtype='float32', name='daily_input')

生成器如下:

def training_data():
while 1:
for i in range(0,pdf_daily_data.shape[0]-n_history,1):
x = f(i)() # f(i) shape is (1, n_history, n_cols)
y = y(i)
yield (x,y)

然后拟合是:

model.fit_generator(training_data(),
steps_per_epoch=pdf_daily_data.shape[0]//batch_size,...

这很有效,训练也很好,但是非常慢,并且自 batch_size = 1

起在每个时间步都执行梯度更新

在此配置中,如何设置 batch_size > 1记住:LSTM 层有 stateful = True

最佳答案

您将必须修改您的生成器以产生您希望批处理拥有的所需数量的元素

目前您正在逐个元素地迭代数据(根据 range() 的第三个参数),获得一个单个 xy,然后生成该元素。当您返回单个元素时,您将获得 batch_size=1,因为您的 fit_generator 正在逐个元素地进行训练。

假设您希望批量大小为 10,那么您必须对数据进行切片并获取每个 10 个元素的片段,然后yield 这些片段而不是单个元素。只需确保将这些更改相应地反射(reflect)到输入层的形状上,并传递相应的 batch_size

关于python - Keras Stateful LSTM fit_generator 如何使用 batch_size > 1,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/48382859/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com