gpt4 book ai didi

python - 如何在 TensorFlow 中使用我自己的数据将图像拆分为测试集和训练集

转载 作者:行者123 更新时间:2023-12-04 01:34:11 32 4
gpt4 key购买 nike

我在这里有点困惑...我刚刚花了最后一个小时阅读如何在 TensorFlow 中将我的数据集拆分为测试/训练。我正在按照本教程导入图像:https://www.tensorflow.org/tutorials/load_data/images .显然可以使用 sklearn 拆分为训练/测试:model_selection.train_test_split

但我的问题是:我什么时候将我的数据集拆分为训练/测试。我已经用我的数据集完成了这个(见下文),现在呢?我该如何拆分它?我是否必须在将文件加载为 tf.data.Dataset 之前执行此操作?

# determine names of classes
CLASS_NAMES = np.array([item.name for item in data_dir.glob('*') if item.name != "LICENSE.txt"])
print(CLASS_NAMES)

# count images
image_count = len(list(data_dir.glob('*/*.png')))
print(image_count)


# load the files as a tf.data.Dataset
list_ds = tf.data.Dataset.list_files(str(cwd + '/train/' + '*/*'))

此外,我的数据结构如下所示。没有测试文件夹,没有 val 文件夹。我需要从该训练集中抽取 20% 的样本进行测试。

train
|__ class 1
|__ class 2
|__ class 3

最佳答案

您可以使用tf.keras.preprocessing.image.ImageDataGenerator:

image_generator = tf.keras.preprocessing.image.ImageDataGenerator(validation_split=0.2)
train_data_gen = image_generator.flow_from_directory(directory='train',
subset='training')
val_data_gen = image_generator.flow_from_directory(directory='train',
subset='validation')

请注意,您可能需要设置其他 data-related parameters为您的发电机。

更新:您可以通过skip()take() 获取数据集的两个切片:

val_data = data.take(val_data_size)
train_data = data.skip(val_data_size)

关于python - 如何在 TensorFlow 中使用我自己的数据将图像拆分为测试集和训练集,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/60130918/

32 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com