gpt4 book ai didi

python - 如何在 tfds.load() 之后在 TensorFlow 2.0 中应用数据增强

转载 作者:行者123 更新时间:2023-12-04 11:15:32 33 4
gpt4 key购买 nike

我正在关注 this guide .

它展示了如何使用 tfds.load() 从新的 TensorFlow 数据集下载数据集。方法:

import tensorflow_datasets as tfds    
SPLIT_WEIGHTS = (8, 1, 1)
splits = tfds.Split.TRAIN.subsplit(weighted=SPLIT_WEIGHTS)

(raw_train, raw_validation, raw_test), metadata = tfds.load(
'cats_vs_dogs', split=list(splits),
with_info=True, as_supervised=True)

接下来的步骤展示了如何使用 map 方法将函数应用于数据集中的每个项目:
def format_example(image, label):
image = tf.cast(image, tf.float32)
image = image / 255.0
# Resize the image if required
image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
return image, label

train = raw_train.map(format_example)
validation = raw_validation.map(format_example)
test = raw_test.map(format_example)

然后访问我们可以使用的元素:
for features in ds_train.take(1):
image, label = features["image"], features["label"]

或者
for example in tfds.as_numpy(train_ds):
numpy_images, numpy_labels = example["image"], example["label"]

但是,该指南没有提及有关数据增强的任何内容。我想使用类似于 Keras 的 ImageDataGenerator 类的实时数据增强。我尝试使用:
if np.random.rand() > 0.5:
image = tf.image.flip_left_right(image)

以及 format_example() 中的其他类似增强函数但是,我如何验证它是否正在执行实时增强而不是替换数据集中的原始图像?

我可以通过传递 batch_size=-1 将完整的数据集转换为 Numpy 数组至 tfds.load()然后使用 tfds.as_numpy()但是,这会将所有不需要的图像加载到内存中。我应该可以使用 train = train.prefetch(tf.data.experimental.AUTOTUNE)为下一个训练循环加载足够的数据。

最佳答案

您从错误的方向处理问题。

首先,使用tfds.load下载数据, cifar10例如(为简单起见,我们将使用默认的 TRAINTEST 拆分):

import tensorflow_datasets as tfds

dataloader = tfds.load("cifar10", as_supervised=True)
train, test = dataloader["train"], dataloader["test"]

(您可以使用自定义 tfds.Split 对象来创建验证数据集或其他, see documentation )
traintesttf.data.Dataset对象,以便您可以使用 map , apply , batch和类似的功能。

下面是一个示例,我将在这里(主要使用 tf.image ):
  • 将每个图像转换为 tf.float640-1范围(不要使用官方文档中的这个愚蠢的片段,这样可以确保正确的图像格式)
  • cache()结果,因为这些结果可以在每个 repeat 之后重复使用
  • 随机翻转 left_to_right每张图片
  • 随机改变图像的对比度
  • 混洗数据和批处理
  • 重要提示:当数据集耗尽时重复所有步骤。这意味着在一个 epoch 之后,所有上述转换都会再次应用(除了那些被缓存的)。

  • 这是执行上述操作的代码(您可以将 lambda s 更改为仿函数或函数):
    train = train.map(
    lambda image, label: (tf.image.convert_image_dtype(image, tf.float32), label)
    ).cache().map(
    lambda image, label: (tf.image.random_flip_left_right(image), label)
    ).map(
    lambda image, label: (tf.image.random_contrast(image, lower=0.0, upper=1.0), label)
    ).shuffle(
    100
    ).batch(
    64
    ).repeat()

    这样的 tf.data.Dataset可以直接传给Keras的 fit , evaluatepredict方法。

    验证它实际上是这样工作的

    我看你对我的解释很怀疑,让我们来看一个例子:

    1. 获取一小部分数据

    这是获取单个元素的一种方法,公认不可读且不直观,但是如果您对 Tensorflow 执行任何操作,应该没问题。 :
    # Horrible API is horrible
    element = tfds.load(
    # Take one percent of test and take 1 element from it
    "cifar10",
    as_supervised=True,
    split=tfds.Split.TEST.subsplit(tfds.percent[:1]),
    ).take(1)

    2.重复数据,检查是否相同:

    使用 Tensorflow 2.0人们实际上可以在没有愚蠢的解决方法的情况下做到这一点(几乎):
    element = element.repeat(2)
    # You can iterate through tf.data.Dataset now, finally...
    images = [image[0] for image in element]
    print(f"Are the same: {tf.reduce_all(tf.equal(images[0], images[1]))}")

    它不出所料地返回:
    Are the same: True

    3.用随机增广检查每次重复后数据是否不同

    下面的代码片段 repeat s 单个元素 5 次并检查哪些相等哪些不同。
    element = (
    tfds.load(
    # Take one percent of test and take 1 element
    "cifar10",
    as_supervised=True,
    split=tfds.Split.TEST.subsplit(tfds.percent[:1]),
    )
    .take(1)
    .map(lambda image, label: (tf.image.random_flip_left_right(image), label))
    .repeat(5)
    )

    images = [image[0] for image in element]

    for i in range(len(images)):
    for j in range(i, len(images)):
    print(
    f"{i} same as {j}: {tf.reduce_all(tf.equal(images[i], images[j]))}"
    )

    输出(在我的情况下,每次运行都会不同):
    0 same as 0: True
    0 same as 1: False
    0 same as 2: True
    0 same as 3: False
    0 same as 4: False
    1 same as 1: True
    1 same as 2: False
    1 same as 3: True
    1 same as 4: True
    2 same as 2: True
    2 same as 3: False
    2 same as 4: False
    3 same as 3: True
    3 same as 4: True
    4 same as 4: True

    您可以将这些图像中的每一个转换到 numpy也可以使用 skimage.io.imshow 自己查看图像, matplotlib.pyplot.imshow 或其他替代方案。

    实时数据增强的另一个可视化示例

    This answer使用 Tensorboard 提供关于数据增强的更全面和可读的 View 和 MNIST ,可能想检查一下(是的,无耻的插件,但我猜很有用)。

    关于python - 如何在 tfds.load() 之后在 TensorFlow 2.0 中应用数据增强,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/55141076/

    33 4 0
    Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
    广告合作:1813099741@qq.com 6ren.com