gpt4 book ai didi

TensorFlow 将数据加载到 tf.Dataset 所需的时间太长

转载 作者:行者123 更新时间:2023-12-02 20:19:05 27 4
gpt4 key购买 nike

我正在使用 TensorFlow 1.9 来训练图像数据集,该数据集太大,无法从我的硬​​盘加载到 RAM 中。因此,我将硬盘上的数据集分成两半。我想知道在整个数据集上训练最有效的方法是什么。

我的 GPU 具有 3 GB 内存,我的 RAM 具有 32 GB 内存。每半数据集的大小为 20 GB。我的硬盘有足够的可用空间(超过 1 TB)。

我的尝试如下。我创建了一个可初始化的 tf.Dataset,然后在每个时期,我将其初始化两次:为数据集的每一半初始化一次。这样,每个 epoch 都会看到整个数据集,但每次只需将其中一半加载到 RAM 中。

但是,这非常慢,因为从硬盘加载数据需要很长时间,而且每次用这些数据初始化数据集也需要很长时间。

有更有效的方法吗?

在加载数据集的另一半之前,我尝试对数据集的每一半进行多个时期的训练,这要快得多,但这会导致验证数据的性能更差。据推测,这是因为模型在每一半上都过度拟合,然后无法推广到另一半的数据。

在下面的代码中,我创建并保存了一些测试数据,然后按上述方式加载这些数据。加载每半个数据集的时间约为 5 秒,使用该数据初始化数据集的时间约为 1 秒。这可能看起来只是一小部分,但它是在多个时期内累积起来的。事实上,我的计算机加载数据所花费的时间几乎与实际训练数据所花费的时间一样多。

import tensorflow as tf
import numpy as np
import time

# Create and save 2 datasets of test NumPy data
dataset_num_elements = 100000
element_dim = 10000
batch_size = 50
test_data = np.zeros([2, int(dataset_num_elements * 0.5), element_dim], dtype=np.float32)
np.savez('test_data_1.npz', x=test_data[0])
np.savez('test_data_2.npz', x=test_data[1])

# Create the TensorFlow dataset
data_placeholder = tf.placeholder(tf.float32, [int(dataset_num_elements * 0.5), element_dim])
dataset = tf.data.Dataset.from_tensor_slices(data_placeholder)
dataset = dataset.shuffle(buffer_size=dataset_num_elements)
dataset = dataset.repeat()
dataset = dataset.batch(batch_size=batch_size)
dataset = dataset.prefetch(1)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
init_op = iterator.initializer

num_batches = int(dataset_num_elements / batch_size)

with tf.Session() as sess:
while True:
for dataset_section in range(2):
# Load the data from the hard drive
t1 = time.time()
print('Loading')
loaded_data = np.load('test_data_' + str(dataset_section + 1) + '.npz')
x = loaded_data['x']
print('Loaded')
t2 = time.time()
loading_time = t2 - t1
print('Loading time = ' + str(loading_time))
# Initialize the dataset with this loaded data
t1 = time.time()
sess.run(init_op, feed_dict={data_placeholder: x})
t2 = time.time()
initialization_time = t2 - t1
print('Initialization time = ' + str(initialization_time))
# Read the data in batches
for i in range(num_batches):
x = sess.run(next_element)

最佳答案

Feed 并不是输入数据的有效方式。您可以像这样输入数据:

  1. 创建包含所有输入文件名的文件名数据集。您可以在此处随机播放、重复数据集。
  2. 将数据集映射到数据,映射功能是读取、解码、变换图像。使用多线程进行 map 转换。
  3. 预取要训练的数据。

这只是一个示例方法。您可以设计自己的管道,请记住以下几点:

  • 尽可能使用轻质饲料
  • 使用多线程读取和预处理
  • 预取数据进行训练

关于TensorFlow 将数据加载到 tf.Dataset 所需的时间太长,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/51813951/

27 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com