gpt4 book ai didi

tensorflow - 遍历 keras 中所有批处理的生成器

转载 作者:行者123 更新时间:2023-12-05 05:07:48 25 4
gpt4 key购买 nike

我有一个由图像和标签组成的数据集,加载了一个生成器,例如:

generator = image_generator.flow_from_directory(batch_size=BATCH_SIZE,
directory=val_dir,
shuffle=False,
target_size=(100,100),
class_mode='categorical')

在使用 CNN 进行预测后,我想遍历所有结果并打印原始图像和预测标签。

使用:

x,y = generator.next()

我设法做到了这一点,但我受限于一批生成器中的元素数量。尝试打印更多循环超出索引。

如何使用这种方法遍历批处理以获得所有结果?

最佳答案

official link 中所述:

for e in range(epochs):
print('Epoch', e)
batches = 0
for x_batch, y_batch in datagen.flow(x_train, y_train, batch_size=32):
model.fit(x_batch, y_batch)
batches += 1
if batches >= len(x_train) / 32:
# we need to break the loop by hand because
# the generator loops indefinitely
break

您可以模仿此示例代码来获取数据生成中的每个批处理。一个重要的注意事项是,如果需要,您应该设置相同的种子以保持批处理的顺序。

关于tensorflow - 遍历 keras 中所有批处理的生成器,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/58850428/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com