
python - Slow serialization of ndarray to TFRecord


I want to serialize a large numpy ndarray to TFRecord. The problem is that the process is painfully slow: for an array of shape (1000000, 65) it takes almost a minute, while serializing the same data to other binary formats (HDF5, npy, parquet, ...) takes less than a second. I'm fairly sure there is a faster way to do this, but I just can't figure it out.

import numpy as np
import tensorflow as tf

X = np.random.randn(1000000, 65)

def write_tf_dataset(data: np.ndarray, path: str):
    with tf.io.TFRecordWriter(path=path) as writer:
        for record in data:
            feature = {'X': tf.train.Feature(float_list=tf.train.FloatList(value=record[:42])),
                       'Y': tf.train.Feature(float_list=tf.train.FloatList(value=record[42:64])),
                       'Z': tf.train.Feature(float_list=tf.train.FloatList(value=[record[64]]))}
            example = tf.train.Example(features=tf.train.Features(feature=feature))
            serialized = example.SerializeToString()
            writer.write(serialized)

write_tf_dataset(X, 'X.tfrecord')

How can I improve the performance of write_tf_dataset? My real X is 200 times larger than in the snippet.

I'm not the first to run into slow TFRecord serialization. Based on this Tensorflow Github issue I wrote a second version of the function:

import pickle

def write_tf_dataset(data: np.ndarray, path: str):
    with tf.io.TFRecordWriter(path=path) as writer:
        for record in data:
            feature = {
                'X': tf.io.serialize_tensor(record[:42]).numpy(),
                'Y': tf.io.serialize_tensor(record[42:64]).numpy(),
                'Z': tf.io.serialize_tensor(record[64]).numpy(),
            }
            serialized = pickle.dumps(feature)
            writer.write(serialized)

...but it performed even worse. Any ideas?

Best Answer

A workaround is to use the multiprocessing package. You can either have several processes serialize records for the same TFRecord file, or have each process write its own file (I believe multiple (smaller) TFRecord files are actually recommended over a single (large) one, because reading from several sources in parallel is faster):

import multiprocessing
import os

import numpy as np
import tensorflow as tf


def serialize_example(record):
    feature = {
        "X": tf.train.Feature(float_list=tf.train.FloatList(value=record[:42])),
        "Y": tf.train.Feature(float_list=tf.train.FloatList(value=record[42:64])),
        "Z": tf.train.Feature(float_list=tf.train.FloatList(value=[record[64]])),
    }
    example = tf.train.Example(features=tf.train.Features(feature=feature))
    return example.SerializeToString()


def write_tfrecord(tfrecord_path, records):
    with tf.io.TFRecordWriter(tfrecord_path) as writer:
        for item in records:
            serialized = serialize_example(item)
            writer.write(serialized)


if __name__ == "__main__":
    np.random.seed(1234)
    data = np.random.randn(1000000, 65)

    # Option 1: serialize in a process pool, write to a single file
    tfrecord_path = "/home/appuser/data/data.tfrecord"
    p = multiprocessing.Pool(4)
    with tf.io.TFRecordWriter(tfrecord_path) as writer:
        for example in p.map(serialize_example, data):
            writer.write(example)

    # Option 2: each process writes its own shard file
    procs = []
    n_shard = 4
    num_per_shard = int(np.ceil(len(data) / n_shard))
    for shard_id in range(n_shard):
        filename = f"data_{shard_id + 1:04d}_of_{n_shard:04d}.tfrecord"
        tfrecord_path = os.path.join("/home/appuser/data", filename)

        start_index = shard_id * num_per_shard
        end_index = min((shard_id + 1) * num_per_shard, len(data))

        args = (tfrecord_path, data[start_index:end_index])
        p = multiprocessing.Process(target=write_tfrecord, args=args)
        p.start()
        procs.append(p)

    for proc in procs:
        proc.join()
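Not part of the original answer, but to illustrate where the multiple-shard layout pays off: a tf.data pipeline can interleave reads across the shard files written by Option 2. A minimal sketch, assuming the shard filename pattern and the 42/22/1 feature split used above:

import tensorflow as tf

# Feature spec matching the 42/22/1 split written by serialize_example above.
feature_spec = {
    "X": tf.io.FixedLenFeature([42], tf.float32),
    "Y": tf.io.FixedLenFeature([22], tf.float32),
    "Z": tf.io.FixedLenFeature([1], tf.float32),
}

def parse_example(serialized):
    return tf.io.parse_single_example(serialized, feature_spec)

# Match only the shard files (data_0001_of_0004.tfrecord, ...).
files = tf.data.Dataset.list_files("/home/appuser/data/data_*.tfrecord")
dataset = (
    files.interleave(
        tf.data.TFRecordDataset,
        num_parallel_calls=tf.data.AUTOTUNE,  # read the shards in parallel
    )
    .map(parse_example, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(1024)
    .prefetch(tf.data.AUTOTUNE)
)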

Regarding python - slow serialization of ndarray to TFRecord, a similar question can be found on Stack Overflow: https://stackoverflow.com/questions/62174662/
