gpt4 book ai didi

tensorflow - 关于使用 tf.train.shuffle_batch() 创建批处理

转载 作者:行者123 更新时间:2023-12-02 23:02:51 25 4
gpt4 key购买 nike

Tensorflow tutorial ,它给出了以下关于 tf.train.shuffle_batch() 的示例:

# Creates batches of 32 images and 32 labels.
image_batch, label_batch = tf.train.shuffle_batch(
[single_image, single_label],
batch_size=32,
num_threads=4,
capacity=50000,
min_after_dequeue=10000)

我不太清楚capacitymin_after_dequeue的含义。在本例中,分别设置为5000010000。这种设置的逻辑是什么,或者说是什么意思。如果输入有 200 个图像和 200 个标签,会发生什么?

最佳答案

tf.train.shuffle_batch()函数使用 tf.RandomShuffleQueue在内部累积一批 batch_size 元素,这些元素是从当前队列中的元素中均匀随机采样的。

许多训练算法,例如 TensorFlow 用于优化神经网络的基于随机梯度下降的算法,都依赖于从整个训练集中均匀随机地采样记录。然而,将整个训练集加载到内存中(以便从中采样)并不总是可行的,因此 tf.train.shuffle_batch() 提供了一个折衷方案:它用之间的值填充内部缓冲区min_after_dequeuecapacity 元素,并从该缓冲区中随机均匀采样。对于许多训练过程来说,这提高了模型的准确性并提供了足够的随机化。

min_after_dequeuecapacity 参数对训练性能有间接影响。设置较大的 min_after_dequeue 值会延迟训练的开始,因为 TensorFlow 在训练开始之前必须处理至少那么多的元素。容量是输入管道将消耗的内存量的上限:将其设置得太大可能会导致训练过程耗尽内存(并且可能开始交换,这将损害训练吞吐量)。

如果数据集只有 200 张图像,则可以轻松地将整个数据集加载到内存中。 tf.train.shuffle_batch() 效率非常低,因为它将每个图像和标签多次放入 tf.RandomShuffleQueue 中。在这种情况下,您可能会发现使用 tf.train.slice_input_producer() 执行以下操作更有效和 tf.train.batch() :

random_image, random_label = tf.train.slice_input_producer([all_images, all_labels],
shuffle=True)

image_batch, label_batch = tf.train.batch([random_image, random_label],
batch_size=32)

关于tensorflow - 关于使用 tf.train.shuffle_batch() 创建批处理,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/39283605/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com