gpt4 book ai didi

python - Keras fit_generator 一次训练一个样本,而我从生成器中生成多个样本

转载 作者:行者123 更新时间:2023-11-30 09:42:53 24 4
gpt4 key购买 nike

我正在使用 keras 训练模型。我尝试了“fit”和“fit_generator”功能。我不明白为什么性能有很大差异,可能是我做错了什么。这是我第一次编写batch_generator代码。

给定批量大小为 10,我观察到的是使用函数时 -

适合:训练速度更快(每个时期大约 3 分钟),详细计数随着批处理大小的倍数增加(此处为 10)
样本-80/7632 [................................] - 预计到达时间:4:31 - 损失:2.2072 - 加速:0.4375

fit_generator:训练速度慢得多(每轮 10 分钟),详细计数一次增加 1(不等于批量大小)
样本-37/7632 [................................] - 预计到达时间:42:25 - 损失:2.1845 - 加速:0.3676

正如您所看到的,对于同一数据集的 fit_generator 来说,ETA 太高了。并且fit_generator每次增加1,而fit则以10的倍数增加

生成器:

def batch_generator(X ,y, batch_size=10):
from sklearn.utils import shuffle

batch_count = int(len(X) / batch_size)
extra = len(X) - (batch_count * batch_size)

while 1:
#shuffle X and y
X_train, y_train = shuffle(X,y)

#Yeild Batches
for i in range(1, batch_count):
batch_start = (i-1) * batch_size
batch_end = i * batch_size
X_batch = X_train[batch_start: batch_end]
y_batch = y_train[batch_start: batch_end]
yield X_batch, y_batch

#Yeild Remaining Data less than batch size
if(extra > 0):
batch_start = batch_count * batch_size
X_batch = X_train[batch_start: -1]
y_batch = y_train[batch_start: -1]
yield X_batch, y_batch

拟合函数:

model.fit_generator(batch_generator(X, y, 10),
verbose = 1,
samples_per_epoch = len(X),
epochs = 20,
validation_data = (X_test, y_test),
callbacks = callbacks_list)

谁能解释一下为什么会发生这种情况?

最佳答案

fit_generator 不使用样本,它使用步骤,您正在使用带有 samples_per_epoch 参数的旧 Keras API,这是不正确的并且会产生错误的结果。正确的 fit_generator 调用是:

model.fit_generator(batch_generator(X, y, 10),
verbose = 1,
steps_per_epoch = int(len(X) / batch_size),
epochs = 20,
validation_data = (X_test, y_test),
callbacks = callbacks_list)

steps_per_epoch 控制在声明纪元​​结束之前要使用多少步(调用生成器)。应将其设置为总样本数除以批量大小。对于fit_generator,进度条中的索引将引用步骤(批处理),而不是样本,因此您无法将它们直接与fit进度条中的索引进行比较。

关于python - Keras fit_generator 一次训练一个样本,而我从生成器中生成多个样本,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/56598285/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com