gpt4 book ai didi

python - 与队列相比,Tensorflow 数据集非常慢

转载 作者:太空宇宙 更新时间:2023-11-03 15:49:47 26 4
gpt4 key购买 nike

使用 Dataset-API 执行相同的任务似乎比使用队列慢 10-100 倍。

这就是我试图用数据集做的事情:

dataset = tf.data.TFRecordDataset(filenames).repeat()
dataset = dataset.batch(100)
dataset = dataset.map(_parse_function)
dataset = dataset.prefetch(1000)
d = dataset.make_one_shot_iterator()

%timeit -n 200 sess.run(d.get_next())

还有队列:

filename_queue = tf.train.string_input_producer(filenames, capacity=1)

reader = tf.TFRecordReader()
_, serialized_example = reader.read_up_to(filename_queue, 100)

features = _parse_function(serialized_example)

coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
tf.train.start_queue_runners()

%timeit -n 200 sess.run(features)

观察结果:

数据集:每个循环 23.6 毫秒 ± 8.73 毫秒(7 次运行的平均值 ± 标准偏差,每次 200 次循环)

队列:每个循环 481 µs ± 91.7 µs(7 次运行的平均值 ± 标准偏差,每次 200 次循环)

为什么会这样?如何让数据集工作得更快?


使用 tensorflow 1.4 和 python 3.5

要重现的完整代码:

import tensorflow as tf
import numpy as np
import glob
import os


def _int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=value))


def create_data(i):
tfrecords_filename = '_temp/dstest/tt%d.tfr' % i

writer = tf.python_io.TFRecordWriter(tfrecords_filename)

for j in range(1000):
f = tf.train.Features(feature={
'x': _int64_feature([j]),
"y": _int64_feature(np.random.randint(5, 100, size=np.random.randint(6)))
})

example = tf.train.Example(features=f)
writer.write(example.SerializeToString())

writer.close()
return tfrecords_filename


def _parse_function(example_proto):
features = {
"x": tf.FixedLenFeature((), tf.int64),
"y": tf.FixedLenSequenceFeature((), tf.int64, allow_missing=True)
}
parsed_features = tf.parse_example(example_proto, features)
return parsed_features


os.makedirs("_temp/dstest", exist_ok=True)
sess = tf.InteractiveSession()

filenames = [create_data(i) for i in range(5)]

#### DATASET
dataset = tf.data.TFRecordDataset(filenames).repeat()
dataset = dataset.batch(100)
dataset = dataset.map(_parse_function)
dataset = dataset.prefetch(1000)
d = dataset.make_one_shot_iterator()

%timeit -n 200 sess.run(d.get_next())

#### QUEUE
filename_queue = tf.train.string_input_producer(filenames, capacity=1)

reader = tf.TFRecordReader()
_, serialized_example = reader.read_up_to(filename_queue, 100)

features = _parse_function(serialized_example)

coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
tf.train.start_queue_runners()

%timeit -n 200 sess.run(features)

coord.request_stop()
coord.join(threads)

最佳答案

哦,我想通了。我不应该多次调用 d.get_next()

当我把它改成:

d = dataset.make_one_shot_iterator().get_next()
%timeit -n 200 sess.run(d)

然后速度与队列版本相似,即使没有预取。

并且需要 sess.run 调用的结果总是不同的。

关于python - 与队列相比,Tensorflow 数据集非常慢,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/47652590/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com