gpt4 book ai didi

python - Tensorflow:在cpu上的多个线程中加载数据

转载 作者:太空狗 更新时间:2023-10-29 20:22:24 25 4
gpt4 key购买 nike

我有一个 python 类 SceneGenerator,它有多个用于预处理的成员函数和一个生成器函数 generate_data()。基本结构是这样的:

class SceneGenerator(object):
def __init__(self):
# some inits

def generate_data(self):
"""
Generator. Yield data X and labels y after some preprocessing
"""
while True:
# opening files, selecting data
X,y = self.preprocess(some_params, filenames, ...)

yield X, y

我在keras model.fit_generator() 函数中使用类成员函数sceneGenerator.generate_data() 从磁盘读取数据,预处理并产生它。在 keras 中,这是在多个 CPU 线程上完成的,如果 model.fit_generator()workers 参数设置为 > 1。

我现在想在 tensorflow 中使用相同的 SceneGenerator 类。我目前的做法是:

sceneGenerator = SceneGenerator(some_params)
for X, y in sceneGenerator.generate_data():

feed_dict = {ops['data']: X,
ops['labels']: y,
ops['is_training_pl']: True
}
summary, step, _, loss, prediction = sess.run([optimization_op, loss_op, pred_op],
feed_dict=feed_dict)

然而,这很慢并且不使用多线程。我找到了 tf.data.Dataset api 与一些 documentation ,但我未能实现这些方法。

编辑:请注意,我不处理图像,因此带有文件路径等的图像加载机制在这里不起作用。我的 SceneGenerator 从 hdf5 文件加载数据。但不是完整的数据集,而是——取决于初始化参数——只是数据集的一部分。我很乐意保持生成器功能不变,并了解如何将此生成器直接用作 tensorflow 的输入并在 CPU 上的多个线程上运行。将 hdf5 文件中的数据重写为 csv 不是一个好的选择,因为它会重复大量数据。

编辑 2::我认为与此类似的东西可能会有所帮助:parallelising tf.data.Dataset.from_generator

最佳答案

假设您使用的是最新的 Tensorflow(撰写本文时为 1.4),您可以保留生成器并使用 tf.data.* API如下(我为线程数、预取缓冲区大小、批处理大小和输出数据类型选择了任意值):

NUM_THREADS = 5
sceneGen = SceneGenerator()
dataset = tf.data.Dataset.from_generator(sceneGen.generate_data, output_types=(tf.float32, tf.int32))
dataset = dataset.map(lambda x,y : (x,y), num_parallel_calls=NUM_THREADS).prefetch(buffer_size=1000)
dataset = dataset.batch(42)
X, y = dataset.make_one_shot_iterator().get_next()

为了表明它实际上是从生成器中提取的多个线程,我将您的类修改如下:

import threading    
class SceneGenerator(object):
def __init__(self):
# some inits
pass

def generate_data(self):
"""
Generator. Yield data X and labels y after some preprocessing
"""
while True:
# opening files, selecting data
X,y = threading.get_ident(), 2 #self.preprocess(some_params, filenames, ...)
yield X, y

这样,创建一个 Tensorflow session 并获取一批显示获取数据的线程的线程 ID。在我的电脑上,运行:

sess = tf.Session()
print(sess.run([X, y]))

打印

[array([  8460.,   8460.,   8460.,  15912.,  16200.,  16200.,   8460.,
15912., 16200., 8460., 15912., 16200., 16200., 8460.,
15912., 15912., 8460., 8460., 6552., 15912., 15912.,
8460., 8460., 15912., 9956., 16200., 9956., 16200.,
15912., 15912., 9956., 16200., 15912., 16200., 16200.,
16200., 6552., 16200., 16200., 9956., 6552., 6552.], dtype=float32),
array([2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])]

注意:您可能想尝试删除 map 调用(我们只使用它来拥有多个线程)并检查 prefetch 的缓冲区足以消除输入管道中的瓶颈(即使只有一个线程,输入预处理通常比实际图形执行速度更快,因此缓冲区足以让预处理尽可能快地进行)。

关于python - Tensorflow:在cpu上的多个线程中加载数据,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/47568998/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com