gpt4 book ai didi

Tensorflow cifar同步点

转载 作者:行者123 更新时间:2023-12-03 00:36:58 25 4
gpt4 key购买 nike

阅读https://github.com/tensorflow/models/blob/master/tutorials/image/cifar10/cifar10_multi_gpu_train.py对于函数average_gradients,提供以下注释:请注意,此函数提供跨所有塔的同步点。函数average_gradients是阻塞调用吗? 同步点是什么意思?

我假设这是一个阻塞调用,因为为了计算梯度的平均值,每个梯度必须单独计算?但是等待所有单独梯度计算的阻塞代码在哪里?

最佳答案

average_gradients其本身并不是阻塞函数。它可能是另一个具有 tensorflow 操作的函数,并且这仍然是一个同步点。使它阻塞的原因是它使用参数 tower_grads 这取决于之前 for 循环中创建的所有图表。

基本上,这里发生的是训练图的创建。首先,在for循环中for i in xrange(FLAGS.num_gpus)创建了多个图形“线程”。每个看起来像这样:

计算损失 --> 计算梯度 --> 附加到 tower_grads

每个图形“线程”都通过 with tf.device('/gpu:%d' % i) 分配给不同的 GPU。并且每个都可以彼此独立运行(并且稍后将并行运行)。现在下次tower_grads在没有设备规范的情况下使用,它会在主设备上创建一个图形延续,将所有这些单独的图形“线程”绑定(bind)到一个线程中。 Tensorflow 将确保作为 tower_grads 创建的一部分的每个图“线程”在运行 average_gradients 内的图表之前已完成功能。因此稍后当 sess.run([train_op, loss])被调用,这将是图的同步点。

关于Tensorflow cifar同步点,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/43766223/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com