
python - TensorFlow `tf.layers.batch_normalization` does not add update ops to `tf.GraphKeys.UPDATE_OPS`

Reposted · Author: 太空狗 · Updated: 2023-10-29 21:48:24

The following code (runnable as-is via copy/paste) illustrates the use of tf.layers.batch_normalization:

import tensorflow as tf
bn = tf.layers.batch_normalization(tf.constant([0.0]))
print(tf.get_collection(tf.GraphKeys.UPDATE_OPS))

> [] # UPDATE_OPS collection is empty

With TF 1.5, the documentation (quoted below) explicitly states that UPDATE_OPS should not be empty in this case (https://www.tensorflow.org/api_docs/python/tf/layers/batch_normalization):

Note: when training, the moving_mean and moving_variance need to be updated. By default the update ops are placed in tf.GraphKeys.UPDATE_OPS, so they need to be added as a dependency to the train_op. For example:

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = optimizer.minimize(loss)
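To see why the update ops have to run, here is a minimal pure-Python sketch (not TensorFlow code) of the moving statistics. Assumptions: `train` and `batch_stats` are hypothetical names; momentum 0.99 and the initial values 0.0 and 1.0 mirror the defaults of `tf.layers.batch_normalization` (moving_mean starts at zeros, moving_variance at ones); the population (biased) variance is used as a simplification. Skipping the update, which is what happens when train_op does not depend on UPDATE_OPS, leaves the moving statistics at their initial values.

```python
# Hypothetical pure-Python model of the batch-norm moving statistics (not TF code).
MOMENTUM = 0.99  # mirrors the default `momentum` of tf.layers.batch_normalization


def batch_stats(batch):
    """Return the mean and population (biased) variance of a list of floats."""
    mean = sum(batch) / len(batch)
    var = sum((x - mean) ** 2 for x in batch) / len(batch)
    return mean, var


def train(batches, run_update_ops):
    """Return (moving_mean, moving_variance) after processing every batch.

    The statistics start at 0.0 and 1.0 and change only when the update ops
    run, like a train_op built without control_dependencies on UPDATE_OPS.
    """
    moving_mean, moving_var = 0.0, 1.0
    for batch in batches:
        mean, var = batch_stats(batch)
        # (the gradient step on gamma and beta would happen here)
        if run_update_ops:
            moving_mean = MOMENTUM * moving_mean + (1 - MOMENTUM) * mean
            moving_var = MOMENTUM * moving_var + (1 - MOMENTUM) * var
    return moving_mean, moving_var


batches = [[4.0, 6.0]] * 1000  # every batch has mean 5.0 and variance 1.0
print(train(batches, run_update_ops=False))  # (0.0, 1.0): never updated
print(train(batches, run_update_ops=True))   # moving mean has approached 5.0
```

Only the run with the updates enabled ends up with statistics that describe the data, which is what the `tf.control_dependencies(update_ops)` wrapper guarantees in the real graph.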

Best answer

Just switch the code to training mode (by setting the `training` flag to `True`), as described in the quote:

Note: when training, the moving_mean and moving_variance need to be updated. By default the update ops are placed in tf.GraphKeys.UPDATE_OPS, so they need to be added as a dependency to the train_op.

import tensorflow as tf
bn = tf.layers.batch_normalization(tf.constant([0.0]), training=True)
print(tf.get_collection(tf.GraphKeys.UPDATE_OPS))

Output:

[<tf.Tensor 'batch_normalization/AssignMovingAvg:0' shape=(1,) dtype=float32_ref>,
<tf.Tensor 'batch_normalization/AssignMovingAvg_1:0' shape=(1,) dtype=float32_ref>]
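The two AssignMovingAvg ops are the updates for moving_mean and moving_variance. As far as I can tell from the TF 1.x source, each one is an in-place subtraction of the form `variable -= (variable - value) * (1 - momentum)`, which is algebraically the usual exponential moving average. The pure-Python sketch below (the names `assign_moving_avg` and `ema` are hypothetical) checks that equivalence and the closed form `value * (1 - momentum ** n)` for a statistic that starts at 0 and sees the same batch value n times.

```python
# Hypothetical pure-Python version of one moving-average update (not TF code).


def assign_moving_avg(variable, value, momentum=0.99):
    """Subtractive form: variable -= (variable - value) * (1 - momentum)."""
    return variable - (variable - value) * (1 - momentum)


def ema(variable, value, momentum=0.99):
    """Textbook exponential moving average, algebraically identical to the above."""
    return momentum * variable + (1 - momentum) * value


# Both forms agree up to floating-point rounding.
assert abs(assign_moving_avg(1.0, 3.0) - ema(1.0, 3.0)) < 1e-12

moving_mean = 0.0
for step in range(1, 301):
    moving_mean = assign_moving_avg(moving_mean, 5.0)
    # Starting from 0, after n updates the value is 5.0 * (1 - 0.99 ** n).
    assert abs(moving_mean - 5.0 * (1 - 0.99 ** step)) < 1e-9

print(round(moving_mean, 3))
```

The closed form also shows how slowly the default momentum converges: the initial value still carries a weight of `0.99 ** n` after n training steps.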

Gamma and beta end up in the TRAINABLE_VARIABLES collection:

print(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES))

[<tf.Variable 'batch_normalization/gamma:0' shape=(1,) dtype=float32_ref>,
<tf.Variable 'batch_normalization/beta:0' shape=(1,) dtype=float32_ref>]
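Gamma and beta are the learned scale and offset. The moving statistics are not trainable, which is why they are kept current by the update ops rather than by the optimizer. The forward pass can be sketched in pure Python as follows. Assumptions: `batch_norm` is a hypothetical helper for a single scalar feature, `epsilon=1e-3` mirrors the layer's default, and `training` plays the role of the layer's `training` flag (default `False`): True normalizes with the batch statistics, False with the moving statistics. With moving statistics that were never updated (0.0 and 1.0), inference leaves the input almost unnormalized.

```python
import math


def batch_norm(xs, training, gamma=1.0, beta=0.0,
               moving_mean=0.0, moving_var=1.0, epsilon=1e-3):
    """Hypothetical scalar batch norm: gamma * (x - mean) / sqrt(var + eps) + beta."""
    if training:
        # Training mode: normalize with the statistics of the current batch.
        mean = sum(xs) / len(xs)
        var = sum((x - mean) ** 2 for x in xs) / len(xs)
    else:
        # Inference mode: normalize with the stored moving statistics.
        mean, var = moving_mean, moving_var
    scale = gamma / math.sqrt(var + epsilon)
    return [scale * (x - mean) + beta for x in xs]


batch = [4.0, 6.0]
print(batch_norm(batch, training=True))    # roughly [-1.0, 1.0]
print(batch_norm(batch, training=False))   # roughly [4.0, 6.0]: stale statistics
print(batch_norm(batch, training=False, moving_mean=5.0, moving_var=1.0))
```

The last call, with moving statistics that match the data, reproduces the training-mode result, which is the state the update ops are meant to reach.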

A similar question, python - TensorFlow `tf.layers.batch_normalization` does not add update ops to `tf.GraphKeys.UPDATE_OPS`, can be found on Stack Overflow: https://stackoverflow.com/questions/48874558/
