gpt4 book ai didi

python - 如何在 Tensorflow 中实现以数据集形式存储在磁盘上的 numpy 数组的多线程导入

转载 作者:行者123 更新时间:2023-12-01 01:55:44 31 4
gpt4 key购买 nike

我的数据集的输入和标签分别存储在 10000 个 .npy 文件中。例如inputs/0000.npy,...inputs/9999.npylabels/0000.npy,...labels/9999.npy。虽然每个文件可以独立存储在内存中,但 20k 数组的整个数据集不能存储在内存中。我想实现多线程 CPU 管道以批量导入数据集,例如 batch_size=8

我尝试实现新的 Tensorflow 数据 API 中提到的功能,但没有找到任何满足我要求的示例。所有示例似乎都是针对整个数据集可以加载到 RAM 中的情况。知道如何解决这个问题吗?

最佳答案

我会使用tf.data.Dataset.from_generator()它允许您通过自定义 python 生成器函数使用 Tensorflow 数据 API。这样,您就可以迭代加载每个 .npy 文件,一次只将一个 numpy.ndarray 加载到内存中。假设每个加载的 numpy.ndarray 都是单个实例,您的案例的示例代码可能如下所示:

import tensorflow as tf
import numpy as np
import os


def gen():
inputs_path = ""
labels_path = ""
for input_file, label_file in zip(os.listdir(inputs_path), os.listdir(labels_path)):
x = np.load(os.path.join(inputs_path, input_file))
y = np.load(os.path.join(labels_path, label_file))
yield x, y


INPUT_SHAPE = []
LABEL_SHAPE = []

# Input pipeline
ds = tf.data.Dataset.from_generator(
gen, (tf.float32, tf.int64), (tf.TensorShape(INPUT_SHAPE), tf.TensorShape(LABEL_SHAPE)))
ds = ds.batch(8)
ds_iter = ds.make_initializable_iterator()
inputs_batch, labels_batch = ds_iter.get_next()

我还没有测试过代码。希望对您有帮助!

关于python - 如何在 Tensorflow 中实现以数据集形式存储在磁盘上的 numpy 数组的多线程导入,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/50216747/

31 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com