gpt4 book ai didi

python - 如何在 Tensorflow 中使用 SWA 实现 Batch Norm?

转载 作者:行者123 更新时间:2023-12-05 06:14:31 37 4
gpt4 key购买 nike

我在 Tensorflow 2.2 中使用带有批量归一化层的随机权重平均 (SWA)。对于 Batch Norm,我使用 tf.keras.layers.BatchNormalization。对于 SWA,我使用我自己的代码来平均权重(我在 tfa.optimizers.SWA 出现之前编写了我的代码)。我在多个来源中读到,如果使用 batch norm 和 SWA,我们必须运行前向传递以使某些数据(激活权重和/或动量值的运行均值和 st dev?)可用于 batch norm 层。我不明白的是——尽管读了很多书——正是需要做什么以及如何做。具体来说:

  1. 什么时候必须运行前向/预测 channel ?在每个结束mini-batch,每个 epoch 结束,所有训练结束?
  2. 运行正向传播时,运行平均值和标准差值如何可用到批量规范层?
  3. 这个过程是由 tfa.optimizers.SWA 类神奇地执行的吗?

最佳答案

When must the forward/prediction pass be run? At the end of eachmini-batch, end of each epoch, end of all training?

训练结束。可以这样想,SWA 是通过将您的最终权重与运行平均值交换来执行的。但是所有批量归一化层仍然是根据旧权重的统计数据计算的。所以我们需要向前传球让他们追上。

When the forward pass is run, how are the running mean & stdev valuesmade available to the batch norm layers?

在正常的前向传递(预测)过程中,运行平均值和标准差不会更新。所以我们实际上需要做的是训练网络,而不是更新权重。这就是论文中所说的以“训练模式”运行前向传球时所指的内容。

实现这一目标的最简单方法(据我所知)是重置批量归一化层并训练一个额外的时期,并将学习率设置为 0。

Is this process performed magically by the tfa.optimizers.SWA class?

我不知道。但是,如果您使用的是 Tensorflow Keras,那么我已经制作了这个 Keras SWA callback这就像论文中的那样,包括学习率计划。

关于python - 如何在 Tensorflow 中使用 SWA 实现 Batch Norm?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/62855224/

37 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com