gpt4 book ai didi

python-3.x - Dataset api中的多线程

转载 作者:行者123 更新时间:2023-12-03 22:05:43 25 4
gpt4 key购买 nike

TL;DR:在 tensorflow 0.1.4 中使用 Dataset api 时,如何确保以多线程方式加载数据?

以前我对磁盘中的图像做了类似的事情:

filename_queue = tf.train.string_input_producer(filenames)    
image_reader = tf.WholeFileReader()
_, image_file = image_reader.read(filename_queue)
imsize = 120
image = tf.image.decode_jpeg(image_file, channels=3)
image = tf.image.convert_image_dtype(image, dtype=tf.float32)
image_r = tf.image.resize_images(image, [imsize, imsize])
images = tf.train.shuffle_batch([image_r],
batch_size=20,
num_threads=30,
capacity=200,
min_after_dequeue=0)

这确保将有 20 个线程为下一次学习迭代准备好数据。

现在使用数据集 api,我可以执行以下操作:

dataset = tf.data.Dataset.from_tensor_slices((filenames, filenames_up, filenames_blacked))
dataset = dataset.map(parse_upscaler_corrector_batch)

在此之后我创建了一个迭代器:

sess = tf.Session();
iterator = dataset.make_initializable_iterator();
next_element = iterator.get_next();
sess.run(iterator.initializer);
value = sess.run(next_element)

变量 value 将被传递以进行进一步处理。

那么我如何确保数据是以多线程方式准备的呢?我在哪里可以阅读有关 Dataset api 和多线程数据读取的信息?

最佳答案

所以看起来实现这一点的方法如下:

dataset = dataset.map(parse_upscaler_corrector_batch, num_parallel_calls=12).prefetch(32).batch(self.ex_config.batch_size)

如果更改 num_parallel_calls=12,可以看到网络/硬盘负载和 CPU 负载都出现峰值或下降。

关于python-3.x - Dataset api中的多线程,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/47653644/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com