gpt4 book ai didi

tensorflow - tensorflow 中的批量标准化

转载 作者:行者123 更新时间:2023-12-04 17:32:57 29 4
gpt4 key购买 nike

我注意到 tensorflow 的 api 中已经有批量标准化函数。我不明白的一件事是如何更改训练和测试之间的程序?

批量归一化在测试和训练期间的作用不同。具体来说,在训练期间使用固定的均值和方差。

某处有一些很好的示例代码吗?我看到了一些,但是对于范围变量,它变得令人困惑

最佳答案

你说得对,tf.nn.batch_normalization仅提供实现批量标准化的基本功能。您必须添加额外的逻辑以在训练期间跟踪移动均值和方差,并在推理期间使用经过训练的均值和方差。你可以看看这个example对于非常通用的实现,但不使用 gamma 的快速版本在这儿 :

  beta = tf.Variable(tf.zeros(shape), name='beta')
moving_mean = tf.Variable(tf.zeros(shape), name='moving_mean',
trainable=False)
moving_variance = tf.Variable(tf.ones(shape),
name='moving_variance',
trainable=False)
control_inputs = []
if is_training:
mean, variance = tf.nn.moments(image, [0, 1, 2])
update_moving_mean = moving_averages.assign_moving_average(
moving_mean, mean, self.decay)
update_moving_variance = moving_averages.assign_moving_average(
moving_variance, variance, self.decay)
control_inputs = [update_moving_mean, update_moving_variance]
else:
mean = moving_mean
variance = moving_variance
with tf.control_dependencies(control_inputs):
return tf.nn.batch_normalization(
image, mean=mean, variance=variance, offset=beta,
scale=None, variance_epsilon=0.001)

关于tensorflow - tensorflow 中的批量标准化,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/36978798/

29 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com