gpt4 book ai didi

python - 以多线程方式加载多个 npz 文件

转载 作者:行者123 更新时间:2023-12-01 04:07:56 27 4
gpt4 key购买 nike

我有几个 .npz 文件。所有 .npz 文件都具有相同的结构:每个文件仅包含两个变量,且变量名始终相同。截至目前,我只需循环所有 .npz 文件,检索两个变量值并将它们附加到某个全局变量中:

# Let's assume there are 100 npz files
x_train = []
y_train = []
for npz_file_number in range(100):
data = dict(np.load('{0:04d}.npz'.format(npz_file_number)))
x_train.append(data['x'])
y_train.append(data['y'])

需要一段时间,瓶颈是CPU。 xy 变量附加到 x_trainy_train 变量的顺序并不重要。

有什么方法可以多线程加载多个 .npz 文件吗?

最佳答案

我对 @Brent Washburne 的评论感到惊讶,并决定自己尝试一下。我认为一般问题有两个:

首先,读取数据通常受 IO 限制,因此编写多线程代码通常不会产生很高的性能增益。其次,由于语言本身的设计,在Python中进行共享内存并行化本身就很困难。与原生 c 相比,开销要大得多。

但是让我们看看我们能做什么。

# some imports
import numpy as np
import glob
from multiprocessing import Pool
import os

# creating some temporary data
tmp_dir = os.path.join('tmp', 'nptest')
if not os.path.exists(tmp_dir):
os.makedirs(tmp_dir)
for i in range(100):
x = np.random.rand(10000, 50)
file_path = os.path.join(tmp_dir, '%05d.npz' % i)
np.savez_compressed(file_path, x=x)

def read_x(path):
with np.load(path) as data:
return data["x"]

def serial_read(files):
x_list = list(map(read_x, files))
return x_list

def parallel_read(files):
with Pool() as pool:
x_list = pool.map(read_x, files)
return x_list

好了,东西准备得够多了。让我们了解一下时间。

files = glob.glob(os.path.join(tmp_dir, '*.npz'))

%timeit x_serial = serial_read(files)
# 1 loops, best of 3: 7.04 s per loop

%timeit x_parallel = parallel_read(files)
# 1 loops, best of 3: 3.56 s per loop

np.allclose(x_serial, x_parallel)
# True

它实际上看起来像是一个不错的加速。我使用两个真实核心和两个超线程核心。

<小时/>

要一次运行所有内容并为其计时,您可以执行以下脚本:

from __future__ import print_function
from __future__ import division

# some imports
import numpy as np
import glob
import sys
import multiprocessing
import os
import timeit

# creating some temporary data
tmp_dir = os.path.join('tmp', 'nptest')
if not os.path.exists(tmp_dir):
os.makedirs(tmp_dir)
for i in range(100):
x = np.random.rand(10000, 50)
file_path = os.path.join(tmp_dir, '%05d.npz' % i)
np.savez_compressed(file_path, x=x)

def read_x(path):
data = dict(np.load(path))
return data['x']

def serial_read(files):
x_list = list(map(read_x, files))
return x_list

def parallel_read(files):
pool = multiprocessing.Pool(processes=4)
x_list = pool.map(read_x, files)
return x_list


files = glob.glob(os.path.join(tmp_dir, '*.npz'))
#files = files[0:5] # to test on a subset of the npz files

# Timing:
timeit_runs = 5

timer = timeit.Timer(lambda: serial_read(files))
print('serial_read: {0:.4f} seconds averaged over {1} runs'
.format(timer.timeit(number=timeit_runs) / timeit_runs,
timeit_runs))
# 1 loops, best of 3: 7.04 s per loop

timer = timeit.Timer(lambda: parallel_read(files))
print('parallel_read: {0:.4f} seconds averaged over {1} runs'
.format(timer.timeit(number=timeit_runs) / timeit_runs,
timeit_runs))
# 1 loops, best of 3: 3.56 s per loop

# Examples of use:
x = serial_read(files)
print('len(x): {0}'.format(len(x))) # len(x): 100
print('len(x[0]): {0}'.format(len(x[0]))) # len(x[0]): 10000
print('len(x[0][0]): {0}'.format(len(x[0][0]))) # len(x[0]): 10000
print('x[0][0]: {0}'.format(x[0][0])) # len(x[0]): 10000
print('x[0].nbytes: {0} MB'.format(x[0].nbytes / 1e6)) # 4.0 MB

关于python - 以多线程方式加载多个 npz 文件,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/35328085/

27 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com