gpt4 book ai didi

python - Keras - 如何在 fit_generator() 中使用批处理和时期?

转载 作者:太空狗 更新时间:2023-10-29 16:56:33 26 4
gpt4 key购买 nike

我有一个 8000 帧的视频,我想训练一个 Keras 模型,每批 200 帧。我有一个帧生成器,它逐帧循环播放视频并将 (3 x 480 x 640) 帧累积到形状为 (200, 3, 480, 640) -- (batch size, rgb, frame height, frame width) -- 每200帧产生XY:

import cv2
...
def _frameGenerator(videoPath, dataPath, batchSize):
"""
Yield X and Y data when the batch is filled.
"""
camera = cv2.VideoCapture(videoPath)
width = camera.get(3)
height = camera.get(4)
frameCount = int(camera.get(7)) # Number of frames in the video file.

truthData = _prepData(dataPath, frameCount)

X = np.zeros((batchSize, 3, height, width))
Y = np.zeros((batchSize, 1))

batch = 0
for frameIdx, truth in enumerate(truthData):
ret, frame = camera.read()
if ret is False: continue

batchIndex = frameIdx%batchSize

X[batchIndex] = frame
Y[batchIndex] = truth

if batchIndex == 0 and frameIdx != 0:
batch += 1
print "now yielding batch", batch
yield X, Y

这是如何运行 fit_generator() :

        batchSize = 200
print "Starting training..."
model.fit_generator(
_frameGenerator(videoPath, dataPath, batchSize),
samples_per_epoch=8000,
nb_epoch=10,
verbose=args.verbosity
)

我的理解是,当模型看到 samples_per_epoch 个样本时,一个 epoch 结束,并且 samples_per_epoch = batch size * number of batches = 200 * 40。所以在训练之后对于第 0-7999 帧的一个 epoch,下一个 epoch 将从第 0 帧开始重新训练。这是正确的吗?

使用此设置我预计每个时期有 40 个批处理(每个批处理 200 帧)从生成器传递到 fit_generator;这将是每个时期总共 8000 帧——即 samples_per_epoch=8000。然后对于后续的 epoch,fit_generator 将重新初始化生成器,以便我们从视频的开头再次开始训练。然而事实并非如此。 第一个纪元完成后(在模型记录批处理 0-24 之后),生成器从中断的地方开始。新的 epoch 不应该从训练数据集的开头重新开始吗?

如果我对 fit_generator 的理解有误,请解释。我已经阅读了文档,这个 example ,还有这些 related issues .我在 TensorFlow 后端使用 Keras v1.0.7。此问题也发布在 Keras repo 中.

最佳答案

After the first epoch is complete (after the model logs batches 0-24), the generator picks up where it left off

这是对所发生情况的准确描述。如果你想重置或倒带发电机,你必须在内部进行。请注意,keras 的行为在许多情况下非常有用。例如,您可以在看到 1/2 的数据后结束一个纪元,然后在另一半进行一个纪元,如果生成器状态被重置(这对于更密切地监控验证很有用),这是不可能的。

关于python - Keras - 如何在 fit_generator() 中使用批处理和时期?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/38936016/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com