gpt4 book ai didi

tensorflow - 如何从 ImageDataGenerator 获取 x_train 和 y_train?

转载 作者:行者123 更新时间:2023-12-04 17:21:22 25 4
gpt4 key购买 nike

我正在处理一些图像分类问题,并为此创建了 Y Network。 Y 网络是一种具有两个输入和一个输出的神经网络。如果我们想要拟合我们的 Tensorflow 模型,我们必须在 model.fit() 中提供 x_train 和 y_train。像这样 -

model.fit([x_train, x_train], y_train, epochs=100, batch_size=64)

但是,如果我从 ImageDataGenerator 获取数据,我该如何获取 x_trainy_train 呢?像这样 -

train_generator = train_datagen.flow_from_dataframe(... , batch_size=64, ...)

我尝试通过这种方法获得 x_train 和 y_train:

x_train, y_train = train_generator.next()

但结果 x_train 和 y_train 仅包含 64 图像,我想要我所有的 8644 图像。我无法将 batch_size 增加到 8644,因为它需要更多内存并且 Google Colab 会崩溃。我该怎么办?

最佳答案

您可以从中获取所有图像和标签的列表

class_dict=train_generator.class_indices
labels= train_generator.labels
file_names= train_generator.filenames

类字典对于将类索引与类名相关联很有用,它的形式是{class name, index} 我发现颠倒顺序以获得{index, class name}形式的字典很有用使用下面的代码

for key,value in class_dict.items():
new_dict[value]=key

因此,当您进行预测并使用 index= np.argmax(p) 获取预测的索引时,您可以从中获取相应的类名

class_name=new_dict[index]

关于tensorflow - 如何从 ImageDataGenerator 获取 x_train 和 y_train?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/66011974/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com