gpt4 book ai didi

python - 具有多 GPU 设置的 tf.data.Iterator

转载 作者:行者123 更新时间:2023-12-01 09:04:16 25 4
gpt4 key购买 nike

我看过cifar10 multi-GPU implementation为我自己的 GPU 训练模型的并行化汲取灵感。

我的模型使用来自 TFRecords 的数据,这些数据通过 tf.data.Iterator 进行迭代。类(class)。因此,给定 2 个 GPU,我想做的是调用 iterator.get_next()在 CPU 上为每个 GPU 执行一次(例如两次),进行一些预处理、嵌入查找和其他与 CPU 相关的操作,然后将两批数据输入 GPU。

伪代码:

with tf.device('/cpu:0'):
batches = []
for gpu in multiple_gpus:
single_gpu_batch = cpu_function(iterator.get_next())
batches.append(single_gpu_batch)

....................

for gpu, batch in zip(multiple_gpus, batches):
with tf.device('/device:GPU:{}'.format(gpu.id):
single_gpu_loss = inference_and_loss(batch)
tower_losses.append(single_gpu_loss)
...........
...........

total_loss = average_loss(tower_losses)

问题是,如果从数据中只能提取 1 个或更少的示例,并且我调用 iterator.get_next()两次tf.errors.OutOfRange将引发异常,并且第一次调用 iterator.get_next() 的数据(实际上并没有失败,只是第二个)永远不会通过 GPU。

我考虑过将数据绘制在一个 iterator.get_next() 中稍后调用并拆分它,但是 tf.split批量大小的失败无法除以 GPU 数量。

在多 GPU 设置中实现迭代器消耗的正确方法是什么?

最佳答案

我认为第二个建议是最简单的方法。为了避免最后一批出现分割问题,可以在dataset.batch中使用drop_remainder选项;或者,如果您需要查看所有数据,那么一种可能的解决方案是根据绘制批处理的大小显式设置维度,以便拆分操作永远不会失败:

dataset = dataset.batch(batch_size * multiple_gpus)
iterator = dataset.make_one_shot_iterator()
batches = iterator.get_next()

split_dims = [0] * multiple_gpus
drawn_batch_size = tf.shape(batches)[0]

以贪婪的方式,即适合每个设备上的batch_size 张量,直到用完

#### Solution 1 [Greedy]: 
for i in range(multiple_gpus):
split_dims[i] = tf.maximum(0, tf.minimum(batch_size, drawn_batch_size))
drawn_batch_size -= batch_size

或者以更分散的方式确保每个设备至少获得一个样本(假设multiple_gpus<drawn_batch_size)

### Solution 2 [Spread]
drawn_batch_size -= - multiple_gpus
for i in range(multiple_gpus):
split_dims[i] = tf.maximum(0, tf.minimum(batch_size - 1, drawn_batch_size)) + 1
drawn_batch_size -= batch_size
## Split batches
batches = tf.split(batches, split_dims)

关于python - 具有多 GPU 设置的 tf.data.Iterator,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/52198119/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com