
python - What happens after data augmentation is added?

Reposted. Author: 行者123. Updated: 2023-11-30 09:19:35

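I am using Kaggle's "Dogs vs. Cats" dataset and following TensorFlow's CIFAR-10 tutorial (for convenience, I did not use weight decay, moving averages, or the L2 loss). My network trained successfully, but when I added the data augmentation part to my code, something strange happened: the loss never decreased, even after thousands of steps. Before adding it, everything was fine. The code is as follows: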

import tensorflow as tf  # TensorFlow 1.x queue-based input API

def get_batch(image, label, image_w, image_h, batch_size, capacity, test_flag=False):
    '''
    Args:
        image: list of image file paths
        label: list of integer labels
        image_w: image width
        image_h: image height
        batch_size: batch size
        capacity: the maximum number of elements in the queue
        test_flag: create a test batch if True, otherwise a training batch
    Returns:
        image_batch: 4D tensor [batch_size, height, width, 3], dtype=tf.float32
        label_batch: 1D tensor [batch_size], dtype=tf.int32
    '''

    image = tf.cast(image, tf.string)
    label = tf.cast(label, tf.int32)

    # make an input queue
    input_queue = tf.train.slice_input_producer([image, label])

    label = input_queue[1]
    image_contents = tf.read_file(input_queue[0])
    image = tf.image.decode_jpeg(image_contents, channels=3)

    ####################################################################
    # Data augmentation goes here.
    # At test time, leave the images as they are.

    if not test_flag:
        # RESIZED_IMG is a constant defined elsewhere in the script (not shown).
        image = tf.image.resize_image_with_crop_or_pad(image, RESIZED_IMG, RESIZED_IMG)
        # Randomly crop a [height, width] section of the image.
        distorted_image = tf.random_crop(image, [image_h, image_w, 3])

        # Randomly flip the image horizontally.
        distorted_image = tf.image.random_flip_left_right(distorted_image)

        # Because these operations are not commutative, consider randomizing
        # the order of their operation.
        # NOTE: since per_image_standardization zeros the mean and makes
        # the stddev unit, this likely has no effect see tensorflow#1458.
        distorted_image = tf.image.random_brightness(distorted_image, max_delta=63)

        image = tf.image.random_contrast(distorted_image, lower=0.2, upper=1.8)
    else:
        image = tf.image.resize_image_with_crop_or_pad(image, image_h, image_w)

    ######################################################################

    # Subtract off the mean and divide by the standard deviation of the pixels.
    image = tf.image.per_image_standardization(image)
    # Set the shapes of tensors.
    image.set_shape([image_h, image_w, 3])
    # label.set_shape([1])

    image_batch, label_batch = tf.train.batch([image, label],
                                              batch_size=batch_size,
                                              num_threads=64,
                                              capacity=capacity)

    label_batch = tf.reshape(label_batch, [batch_size])
    image_batch = tf.cast(image_batch, tf.float32)

    return image_batch, label_batch
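The training branch of `get_batch` (crop, flip, brightness, contrast, standardization) can be approximated in NumPy to see what each step does to one image. This is a rough sketch, not TensorFlow's implementation: `augment` and `standardize` are hypothetical names, the 240×240 input and 208×208 output sizes are made up, and the input is assumed to already be float32 in the [0, 255] range (which is what the CIFAR-10 tutorial has after its `tf.cast(..., tf.float32)`). The brightness and contrast steps follow the documented behavior of `adjust_brightness` (add one delta to every pixel) and `adjust_contrast` (scale each channel's deviation from its mean).

```python
import numpy as np

def standardize(img):
    """Imitates the documented formula of tf.image.per_image_standardization:
    (x - mean) / max(stddev, 1 / sqrt(num_elements))."""
    img = img.astype(np.float32)
    adjusted_std = max(img.std(), 1.0 / np.sqrt(img.size))
    return (img - img.mean()) / adjusted_std

def augment(img, out_h, out_w, rng):
    """Hypothetical NumPy analogue of the training branch.
    img: float32 array [H, W, 3] with values on the 0..255 scale (assumption)."""
    h, w, _ = img.shape
    # Random crop of an [out_h, out_w] window.
    top = rng.integers(0, h - out_h + 1)
    left = rng.integers(0, w - out_w + 1)
    out = img[top:top + out_h, left:left + out_w, :]
    # Random horizontal flip with probability 0.5.
    if rng.random() < 0.5:
        out = out[:, ::-1, :]
    # Random brightness: add one delta in [-63, 63] to every pixel
    # (here the delta is on the same 0..255 scale as the pixels).
    out = out + rng.uniform(-63, 63)
    # Random contrast: scale each channel's deviation from its own mean.
    factor = rng.uniform(0.2, 1.8)
    mean = out.mean(axis=(0, 1), keepdims=True)
    out = (out - mean) * factor + mean
    return standardize(out)

rng = np.random.default_rng(0)
image = rng.uniform(0, 255, size=(240, 240, 3)).astype(np.float32)
batch = np.stack([augment(image, 208, 208, rng) for _ in range(4)])
print(batch.shape)  # (4, 208, 208, 3)
```

Each standardized image ends up with roughly zero mean and unit standard deviation, which is why the NOTE in the original code says brightness changes are largely undone by standardization.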

Best answer

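Make sure the limits you use (e.g. max_delta=63 for brightness, upper=1.8 for contrast) are low enough that the images are still recognizable. Another possible problem is that the augmentation is applied over and over again, so that after a few iterations the image is completely distorted (although I did not spot this bug in your snippet).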
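One reason `max_delta=63` can be far too large in this particular snippet: in TensorFlow 1.x, `tf.image.adjust_brightness` (which `random_brightness` calls) converts a non-float image, such as the `uint8` output of `decode_jpeg`, to float32 in [0, 1] before adding the delta, and then converts back with saturation. The CIFAR-10 tutorial casts the image to float32 first, so there a delta of 63 is relative to 255. The NumPy sketch below imitates both cases; `brightness_uint8` and `brightness_float` are hypothetical helper names, and the scaling is a simplified stand-in for `tf.image.convert_image_dtype`.

```python
import numpy as np

def brightness_uint8(img_u8, delta):
    """Simplified imitation of adjust_brightness on a uint8 image:
    scale to [0, 1], add delta, saturate, scale back to uint8."""
    flt = img_u8.astype(np.float32) / 255.0
    adjusted = np.clip(flt + delta, 0.0, 1.0)
    return np.round(adjusted * 255.0).astype(np.uint8)

def brightness_float(img_f32, delta):
    """Same idea on a float32 image already on the 0..255 scale: delta is just added."""
    return img_f32 + delta

rng = np.random.default_rng(0)
img = rng.integers(0, 256, size=(8, 8, 3), dtype=np.uint8)

# A delta of 20 (well inside +/-63) is mild on the 0..255 scale,
# but on the [0, 1] scale it pushes every pixel to the maximum.
washed_out = brightness_uint8(img, 20.0)
print(np.unique(washed_out))  # [255]
print(washed_out.std())       # 0.0

kept = brightness_float(img.astype(np.float32), 20.0)
print(np.isclose(kept.std(), img.astype(np.float32).std()))  # True
```

A negative delta of similar size drives every pixel to 0 instead. Either way the image becomes constant, and neither `random_contrast` nor standardization can bring the content back.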

I suggest adding data visualization to TensorBoard. To visualize images, use the tf.summary.image method. You will be able to see the results of the augmentation clearly.

tf.summary.image('input', image_batch, 10)

This gist can serve as an example.
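Alongside `tf.summary.image`, a quick numeric check on a batch fetched with `sess.run` can flag images that the augmentation has flattened to a single color. Per its documented formula, `per_image_standardization` maps such a constant image to all zeros, so the network gets no signal from it. A minimal NumPy sketch, assuming the batch is fetched before standardization (for example by temporarily returning the un-standardized tensor); `find_flat_images` and the threshold `min_std=1.0` are hypothetical choices.

```python
import numpy as np

def find_flat_images(batch, min_std=1.0):
    """Return indices of images whose pixel standard deviation is below min_std.

    batch: array [N, H, W, 3] of un-standardized pixel values (0..255 scale).
    min_std: hypothetical threshold; tune it for your data.
    """
    per_image_std = batch.reshape(batch.shape[0], -1).std(axis=1)
    return np.flatnonzero(per_image_std < min_std).tolist()

rng = np.random.default_rng(0)
batch = rng.uniform(0, 255, size=(4, 16, 16, 3)).astype(np.float32)
batch[1] = 255.0  # simulate an image washed out to white
batch[3] = 0.0    # simulate an image pushed to black
print(find_flat_images(batch))  # [1, 3]
```

If most of a fetched batch is reported as flat, the augmentation parameters (or the dtype the ops run on) are the first thing to check.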

Regarding "python - What happens after data augmentation is added?", we found a similar question on Stack Overflow: https://stackoverflow.com/questions/45002525/
