gpt4 book ai didi

tensorflow - 如何在 Tensorflow Estimator 的 input_fn 中执行数据增强

转载 作者:行者123 更新时间:2023-11-30 09:17:27 26 4
gpt4 key购买 nike

使用 Tensorflow 的 Estimator API,我应该在管道中的哪个点执行数据增强?

据此官方Tensorflow guide ,执行数据增强的一个位置是 input_fn:

def parse_fn(example):
"Parse TFExample records and perform simple data augmentation."
example_fmt = {
"image": tf.FixedLengthFeature((), tf.string, ""),
"label": tf.FixedLengthFeature((), tf.int64, -1)
}
parsed = tf.parse_single_example(example, example_fmt)
image = tf.image.decode_image(parsed["image"])

# augments image using slice, reshape, resize_bilinear
# |
# |
# |
# v
image = _augment_helper(image)

return image, parsed["label"]

def input_fn():
files = tf.data.Dataset.list_files("/path/to/dataset/train-*.tfrecord")
dataset = files.interleave(tf.data.TFRecordDataset)
dataset = dataset.map(map_func=parse_fn)
# ...
return dataset

我的问题

如果我在 input_fn 内执行数据增强,parse_fn 是否返回单个示例或包含原始输入图像 + 所有增强变体的批处理?如果它只返回一个[增强]示例,我如何确保数据集中的所有图像以及所有变体都以其未增强的形式使用?

最佳答案

如果您在数据集上使用迭代器,则您的 _augment_helper 函数将在输入的每个数据 block 上的数据集的每次迭代中被调用(就像您在 dataset.map 中调用 parse_fn 一样)

将代码更改为

  ds_iter = dataset.make_one_shot_iterator()
ds_iter = ds_iter.get_next()
return ds_iter

我已经用一个简单的增强函数对此进行了测试

  def _augment_helper(image):
print(image.shape)
image = tf.image.random_brightness(image,255.0, 1)
image = tf.clip_by_value(image, 0.0, 255.0)
return image

将 255.0 更改为数据集中的最大值,我使用 255.0 作为示例的数据集采用 8 位像素值

关于tensorflow - 如何在 Tensorflow Estimator 的 input_fn 中执行数据增强,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/51594027/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com