gpt4 book ai didi

tensorflow - tf.data 或 tf.keras.utils.Sequence。提高 tf.data 的效率?

转载 作者:行者123 更新时间:2023-12-04 13:55:39 28 4
gpt4 key购买 nike

我正在尝试使用自动编码器开发图像着色器。有 13000 张训练图像。如果我使用 tf.data,每个 epoch 大约需要 45 分钟,如果我使用 tf.utils.keras.Sequence 大约需要 25 分钟。但是,使用 Sequence 存在死锁的风险。如何改进 tf.data?我尝试了一些东西,但它们似乎没有任何改善。
tf.data 1

image_path_list = glob.glob('datasets/imagenette/*')
data = tf.data.Dataset.list_files(image_path_list)

def tf_rgb2lab(image):
im_shape = image.shape
[image,] = tf.py_function(color.rgb2lab, [image], [tf.float32])
image.set_shape(im_shape)
return image

def preprocess(path):
image = tf.io.read_file(path)
image = tf.image.decode_jpeg(image, channels=3)
image = tf.image.convert_image_dtype(image, tf.float32)
image = tf.image.resize(image, [224, 224])
image = tf_rgb2lab(image)
L = image[:,:,0]/100.
ab = image[:,:,1:]/128.
input = tf.stack([L,L,L], axis=2)
return input, ab

train_ds = data.repeat().map(preprocess, AUTOTUNE).batch(32).prefetch(AUTOTUNE)
tf.data 2
AUTOTUNE = tf.data.experimental.AUTOTUNE

def tf_rgb2lab(image):
im_shape = image.shape
[image,] = tf.py_function(color.rgb2lab, [image], [tf.float32])
image.set_shape(im_shape)
return image

def split_for_feed(image):
L = image[:,:,:,0]/100.
ab = image[:,:,:,1:]/128.
input = tf.stack([L,L,L], axis=-1)
return input, ab

def read_images(path):
image = tf.io.read_file(path)
image = tf.image.decode_jpeg(image, channels=3)
image = tf.image.convert_image_dtype(image, tf.float32)
image = tf.image.resize(image, [224, 224])
image = tf_rgb2lab(image)
return image

data2 = data.repeat().map(read_images, AUTOTUNE).batch(32)
train_ds = data2.map(split_for_feed, AUTOTUNE).prefetch(AUTOTUNE)
序列
class ImageGenerator(tf.keras.utils.Sequence):
def __init__(self, image_filenames, batch_size):
self.image_filenames = image_filenames
self.batch_size = batch_size

def __len__(self):
return math.ceil(len(self.image_filenames) / self.batch_size)

def __getitem__(self, idx):
batch = self.image_filenames[idx * self.batch_size : (idx + 1) * self.batch_size]
X_batch = []
y_batch = []
for file_name in batch:
file_name = 'datasets/imagenette/' + file_name
try:
color_image = transform.resize(io.imread(file_name),(224,224))
lab_image = color.rgb2lab(color_image)
L = lab_image[:,:,0]/100.
ab = lab_image[:,:,1:]/128.
X_batch.append(np.stack((L,L,L), axis=2))
y_batch.append(ab)
except:
pass
return np.array(X_batch), np.array(y_batch)

最佳答案

如果您的数据适合内存,请尝试缓存预处理。代替

train_ds = data.repeat().map(preprocess, AUTOTUNE).batch(32).prefetch(AUTOTUNE)
train_ds = data.map(preprocess, AUTOTUNE).batch(32).cache().repeat().prefetch(AUTOTUNE)
这样您只需解析每个文件一次,而不是重复解析。
如果您希望进一步优化管道,请考虑使用 TF Profiler ,它可以准确告诉您数据集的每个部分花费了多少时间,以便您找到瓶颈并解决它。

关于tensorflow - tf.data 或 tf.keras.utils.Sequence。提高 tf.data 的效率?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/63113565/

28 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com