gpt4 book ai didi

python-3.x - Keras 中的 fit_generator 是否应该在每个纪元后重置生成器?

转载 作者:行者123 更新时间:2023-11-30 08:27:55 28 4
gpt4 key购买 nike

我尝试将 fit_generator 与自定义生成器一起使用来读取对于内存来说太大的数据。我想要训练 125 万行,因此我让生成器一次生成 50,000 行。 fit_generator 有 25 个steps_per_epoch,我认为这将为每个周期带来 1.25MM 的增量。我添加了一个打印语句,以便我可以看到该进程正在执行多少偏移量,并且我发现当它进入第 2 纪元的几步时,它超出了最大值。该文件中总共有 175 万条记录,并且一次它通过了 10 个步骤,在 create_feature_matrix 调用中出现索引错误(因为它没有引入任何行)。

def get_next_data_batch():
import gc
nrows = 50000
skiprows = 0

while True:
d = pd.read_csv(file_loc,skiprows=range(1,skiprows),nrows=nrows,index_col=0)
print(skiprows)
x,y = create_feature_matrix(d)
yield x,y
skiprows = skiprows + nrows
gc.collect()
get_data = get_next_data_batch()

... set up a Keras NN ...

model.fit_generator(get_next_data_batch(), epochs=100,steps_per_epoch=25,verbose=1,workers=4,callbacks=callbacks_list)

我是否使用了 fit_generator 错误,或者是否需要对我的自定义生成器进行一些更改才能使其正常工作?

最佳答案

否 - fit_generator 不会重置生成器,它只是继续调用它。为了实现您想要的行为,您可以尝试以下操作:

def get_next_data_batch(nb_of_calls_before_reset=25):
import gc
nrows = 50000
skiprows = 0
nb_calls = 0

while True:
d = pd.read_csv(file_loc,skiprows=range(1,skiprows),nrows=nrows,index_col=0)
print(skiprows)
x,y = create_feature_matrix(d)
yield x,y
nb_calls += 1
if nb_calls == nb_of_calls_before_reset:
skiprows = 0
else:
skiprows = skiprows + nrows
gc.collect()

关于python-3.x - Keras 中的 fit_generator 是否应该在每个纪元后重置生成器?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/48729107/

28 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com