gpt4 book ai didi

tensorflow - keras中的fit_generator : where is the batch_size specified?

转载 作者:行者123 更新时间:2023-12-03 16:37:18 26 4
gpt4 key购买 nike

嗨,我不了解 keras fit_generator 文档。

我希望我的困惑是理性的。

有一个batch_size还有分批训练的概念。使用 model_fit() ,我指定一个 batch_size 128 个。

对我来说,这意味着我的数据集将一次输入 128 个样本,从而大大减轻了内存。只要我有时间等待,它就应该允许训练 1 亿个样本数据集。毕竟,keras 一次只能“处理”128 个样本。对?

但我高度怀疑指定 batch_size独自一人不会做我想做的事。大量内存仍在使用中。为了我的目标,我需要分批训练 128 个示例。

所以我猜这就是 fit_generator做。我真的很想问为什么不batch_size真的像它的名字所暗示的那样工作吗?

更重要的是,如果 fit_generator需要,我在哪里指定 batch_size ?文档说无限循环。
生成器对每一行循环一次。我如何一次循环超过 128 个样本并记住我上次停止的位置并在下次 keras 要求下一批的起始行号时记忆它(在第一批完成后将是第 129 行)。

最佳答案

您将需要在生成器内部以某种方式处理批量大小。这是一个生成随机批处理的示例:

import numpy as np
data = np.arange(100)
data_lab = data%2
wholeData = np.array([data, data_lab])
wholeData = wholeData.T

def data_generator(all_data, batch_size = 20):

while True:

idx = np.random.randint(len(all_data), size=batch_size)

# Assuming the last column contains labels
batch_x = all_data[idx, :-1]
batch_y = all_data[idx, -1]

# Return a tuple of (Xs,Ys) to feed the model
yield(batch_x, batch_y)

print([x for x in data_generator(wholeData)])

关于tensorflow - keras中的fit_generator : where is the batch_size specified?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/43780193/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com