gpt4 book ai didi

python - 使用 Tensorflow batch_norm 函数获得低测试精度

转载 作者:太空狗 更新时间:2023-10-30 02:27:39 25 4
gpt4 key购买 nike

我在 MNIST 数据上使用 Tensorflow 的官方批量归一化 (BN) 函数 ( tf.contrib.layers.batch_norm() )。我使用以下代码添加 BN:

local4_bn = tf.contrib.layers.batch_norm(local4, is_training=True)

在测试期间,我在上面的代码行中更改了“is_training=False”,结果发现只有 20% 的准确率。但是,如果我将上述代码也用于测试(即保持 is_training=True)且批处理大小为 100 张图像,则准确率约为 99%。此观察结果表明 batch_norm() 计算的指数移动平均值和方差可能不正确,或者我的代码中遗漏了一些东西。

谁能回答一下上述问题的解决方案。

最佳答案

当您使用 is_training=True 测试您的模型时,您可以获得约 99% 的准确率,这仅仅是因为批量大小为 100。如果您将批量大小更改为 1,您的准确性将会降低。

这是因为您要计算输入批处理的指数移动平均值和方差,而不是使用这些值(批处理)归一化层输出。

batch_norm 函数具有参数 variables_collections,可帮助您存储训练阶段计算的移动平均值和方差,并在测试阶段重用它们。

如果您为这些变量定义一个集合,那么 batch_norm 层将在测试阶段使用它们,而不是计算新值。

因此,如果您将批量归一化层定义更改为

local4_bn = tf.contrib.layers.batch_norm(local4, is_training=True, variables_collections=["batch_norm_non_trainable_variables_collection"])

该层会将计算出的变量存储到 "batch_norm_non_trainable_variables_collection" 集合中。

在测试阶段,当您传递 is_training=False 参数时,该层将重新使用它在集合中找到的计算值。

请注意,移动平均值和方差不是可训练参数,因此,如果您仅将模型可训练参数保存在检查点文件中,则必须手动将存储的不可训练变量添加到先前定义的集合中。

您可以在创建 Saver 对象时执行此操作:

saver = tf.train.Saver(tf.get_trainable_variables() + tf.get_collection_ref("batch_norm_non_trainable_variables_co‌​llection") + otherlistofvariables)

在上瘾中,由于批量归一化可以限制所应用层的表达能力(因为它限制了值的范围),你应该让网络学习参数 gammabeta(paper 中描述的仿射变换系数)允许网络学习,因此,仿射变换增加了层的表示能力。

您可以通过将 batch_norm 函数的参数设置为 True 来启用这些参数的学习,方法如下:

local4_bn = tf.contrib.layers.batch_norm(
local4,
is_training=True,
center=True, # beta
scale=True, # gamma
variables_collections=["batch_norm_non_trainable_variables_collection"])

关于python - 使用 Tensorflow batch_norm 函数获得低测试精度,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/40081697/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com