gpt4 book ai didi

tensorflow - 在 TensorFlow 中,如何查看批量归一化参数?

转载 作者:行者123 更新时间:2023-11-30 08:39:34 24 4
gpt4 key购买 nike

我在网络中使用tf.layers.batch_normalization层。如您所知,批量归一化对该层中的每个单元 u_i 使用可训练参数 gamma 和 beta,为各种输入 x 选择其自己的标准差和 u_i(x) 平均值。通常,gamma 初始化为 1,beta 初始化为 0。

我有兴趣查看各个单元正在学习的 gamma 和 beta 值,以收集有关它们在网络训练后最终趋于何处的统计数据。我如何在每个训练实例中查看它们的当前值?

最佳答案

您可以获取批标准化层范围内的所有变量并打印它们。示例:

import tensorflow as tf

tf.reset_default_graph()
x = tf.constant(3.0, shape=(3,))
x = tf.layers.batch_normalization(x)

print(x.name) # batch_normalization/batchnorm/add_1:0

variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
scope='batch_normalization')
print(variables)

#[<tf.Variable 'batch_normalization/gamma:0' shape=(3,) dtype=float32_ref>,
# <tf.Variable 'batch_normalization/beta:0' shape=(3,) dtype=float32_ref>,
# <tf.Variable 'batch_normalization/moving_mean:0' shape=(3,) dtype=float32_ref>,
# <tf.Variable 'batch_normalization/moving_variance:0' shape=(3,) dtype=float32_ref>]

with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
gamma = sess.run(variables[0])
print(gamma) # [1. 1. 1.]

关于tensorflow - 在 TensorFlow 中,如何查看批量归一化参数?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/53282439/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com