gpt4 book ai didi

python - 如果未设置 tf.stop_gradient 会怎样?

转载 作者:行者123 更新时间:2023-11-28 18:00:23 24 4
gpt4 key购买 nike

我正在阅读 tensorflow 模型的 faster-rcnn 代码。我对 tf.stop_gradient 的使用感到困惑。

考虑以下代码片段:

if self._is_training:
proposal_boxes = tf.stop_gradient(proposal_boxes)
if not self._hard_example_miner:
(groundtruth_boxlists, groundtruth_classes_with_background_list, _,
groundtruth_weights_list
) = self._format_groundtruth_data(true_image_shapes)
(proposal_boxes, proposal_scores,
num_proposals) = self._sample_box_classifier_batch(
proposal_boxes, proposal_scores, num_proposals,
groundtruth_boxlists, groundtruth_classes_with_background_list,
groundtruth_weights_list)

更多代码为here .我的问题是:如果未为 proposal_boxes 设置 tf.stop_gradient 会怎样?

最佳答案

这真是个好问题,因为这条简单的tf.stop_gradient 行在训练faster_rcnn 模型时非常关键。这就是为什么在培训期间需要它。

Faster_rcnn 模型是两阶段检测器,损失函数必须满足两个阶段的目标。在 faster_rcnn 中,rpn 损失和 fast_rcnn 损失都需要最小化。

这是论文第 3.2 节的内容

Both RPN and Fast R-CNN, trained independently will modify their convlolutional layers in different ways. We therefore need to develop a technique that allows for sharing convolutional layers between the two networks, rather than learning two separate networks.

然后论文描述了三种训练方案,在原论文中他们采用了第一种方案——Alternating training,即先训练RPN再训练Fast-RCNN。

第二种方案是近似联合训练,实现简单,API采用该方案。 Fast R-CNN 接受来自预测边界框的输入坐标(通过 rpn),因此 Fast R-CNN 损失将具有 w.r.t 边界框坐标的梯度。但在这个训练方案中,这些梯度被忽略,这正是使用tf.stop_gradient的原因。论文报道,这种训练方案将减少训练时间25-50%。

第三种方案是非近似联合训练,所以不需要tf.stop_gradient。该论文报告说,拥有一个可微分 w.r.t 框坐标的 RoI 池化层是一个非常重要的问题。

但是为什么这些梯度被忽略了呢?

事实证明,RoI 池化层是完全可微的,但支持方案二的主要原因是方案三会导致它在训练早期不稳定。

API 的一位作者给出了非常好的答案 here

一些 further reading关于近似联合训练。

关于python - 如果未设置 tf.stop_gradient 会怎样?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/56059078/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com