gpt4 book ai didi

tensorflow - 如何使用 Tensorflow 的 Dataset API 将数据移动到多个 GPU 塔

转载 作者:行者123 更新时间:2023-12-04 00:43:47 32 4
gpt4 key购买 nike

我们正在 Tensorflow 上运行多 GPU 作业,并评估从基于队列的模型(使用 string_input_producer 接口(interface))到新的 Tensorflow 数据集 API 的迁移。后者似乎提供了一种更简单的方式来同时在训练和验证之间切换。

下面的一段代码显示了我们是如何做到这一点的。

    train_dataset, train_iterator = get_dataset(train_files, batch_size, epochs)
val_dataset, val_iterator = get_dataset(val_files, batch_size, epochs)


is_validating = tf.placeholder(dtype=bool, shape=())
next_batch = tf.cond(is_validating,
lambda: val_iterator.get_next(),
lambda: train_iterator.get_next())

validation_tower = self.num_gpus - 1
tower_grads = []

for i in range(self.num_gpus):
with tf.variable_scope(tf.get_variable_scope(),reuse=(i > 0)):
with tf.device('/gpu:%d' % i), tf.name_scope('%s_%d' % ('gpu_', i)) as scope:
if i == validation_tower:
images, labels = next_batch
# Loss funcs snipped out
else:
images, labels = next_batch
# Loss funcs snipped out

get_dataset 函数构建数据集,设置映射函数和批量大小。它还构建了一个迭代器,但不初始化它。迭代器的初始化发生在 session 开始之前。

is_validating bool 值在 session 运行时提供,每隔几步我们通过 feed_dict 将 is_validating 传递为 True 以使用验证数据集

我的问题是:

假设我有 8 个 GPU,所以我们在 7 个 GPU 上运行训练。对于这 7 个 GPU 中的每一个,迭代器是否从同一点前进,从而为所有 7 个 GPU 提供相同的数据?

最佳答案

目前有三个主要选项,它们具有不同的可用性和性能权衡:

  • Dataset.batch() 转换,创建一个包含所有 GPU 示例的大批量。然后使用 tf.split(..., self.num_gpus) Iterator.get_next() 的输出为每个 GPU 创建子批处理。这可能是最简单的方法,但它确实将拆分置于关键路径上。
  • Dataset.batch()转换,创建一个适合单个 GPU 大小的小批量。然后调用Iterator.get_next()每个 GPU 一次以获得多个不同的批处理。 (相比之下,在您当前的代码中,next_batch 的相同值被发送到每个 GPU,这可能不是您想要发生的。)
  • 创建多个迭代器,每个 GPU 一个。使用 Dataset.shard() 对数据进行分片在管道的早期(例如,如果您的数据集被分片,则在文件列表中)。请注意,这种方法会消耗主机上的更多资源,因此您可能需要调低任何缓冲区大小和/或并行度

  • 请注意,当前 tf.data流水线仅在 CPU 上运行,高效流水线的一个重要方面是在上一步仍在运行时将您的训练输入暂存到 GPU。见 TensorFlow CNN benchmarks例如,展示如何有效地将数据暂存到 GPU 的代码。我们目前正在努力将此支持添加到 tf.data直接API。

    关于tensorflow - 如何使用 Tensorflow 的 Dataset API 将数据移动到多个 GPU 塔,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/46965098/

    32 4 0
    Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
    广告合作:1813099741@qq.com 6ren.com