
tensorflow - Significance of the "trainable" and "training" flags in tf.layers.batch_normalization


What is the significance of the "trainable" and "training" flags in tf.layers.batch_normalization? How do the two differ during training and during prediction?

Best Answer

Batch normalization has two phases:

1. Training:
   - Normalize layer activations using the mean and variance of the current mini-batch, then scale and shift with `gamma` and `beta`.
     (`training` should be `True`.)
   - Update the `moving_avg` and `moving_var` statistics.
     (`trainable` should be `True`; see the training-loop sketch below.)
2. Inference:
   - Normalize layer activations using the stored `moving_avg` and `moving_var`, then scale and shift with `gamma` and `beta`.
     (`training` should be `False`.)
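
Note that when `training=True`, the ops that update `moving_avg` and `moving_var` are collected in `tf.GraphKeys.UPDATE_OPS` and are not run automatically; the usual TF 1.x pattern is to make the train step depend on them. A minimal sketch with a made-up toy network (the dense layer, loss and optimizer here are only for illustration, they are not part of the example below):

import tensorflow as tf

# Toy network, only to show the UPDATE_OPS pattern (hypothetical example).
x = tf.placeholder(tf.float32, shape=(None, 4), name='x')
y = tf.placeholder(tf.float32, shape=(None, 1), name='y')
h = tf.layers.batch_normalization(x, training=True, trainable=True)
pred = tf.layers.dense(h, 1)
loss = tf.reduce_mean(tf.square(pred - y))

# The moving_mean/moving_variance update ops live in UPDATE_OPS;
# run them together with the training step.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss)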

Example code illustrating the different cases:

import numpy as np
import tensorflow as tf

# random image
img = np.random.randint(0, 10, (2, 2, 4)).astype(np.float32)

# batch norm params initialized
beta = np.ones((4)).astype(np.float32) * 1        # all ones
gamma = np.ones((4)).astype(np.float32) * 2       # all twos
moving_mean = np.zeros((4)).astype(np.float32)    # all zeros
moving_var = np.ones((4)).astype(np.float32)      # all ones

# placeholder for the input image
_input = tf.placeholder(tf.float32, shape=(1, 2, 2, 4), name='input')

# batch norm layer
out = tf.layers.batch_normalization(
    _input,
    beta_initializer=tf.constant_initializer(beta),
    gamma_initializer=tf.constant_initializer(gamma),
    moving_mean_initializer=tf.constant_initializer(moving_mean),
    moving_variance_initializer=tf.constant_initializer(moving_var),
    training=False, trainable=False)

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
init_op = tf.global_variables_initializer()

## 2. Run the graph in a session

with tf.Session() as sess:

    # init the variables
    sess.run(init_op)

    for i in range(2):
        ops, o = sess.run([update_ops, out], feed_dict={_input: np.expand_dims(img, 0)})
        print('beta', sess.run('batch_normalization/beta:0'))
        print('gamma', sess.run('batch_normalization/gamma:0'))
        print('moving_avg', sess.run('batch_normalization/moving_mean:0'))
        print('moving_variance', sess.run('batch_normalization/moving_variance:0'))
        print('out', np.round(o))
        print('')

When training=False and trainable=False:

  img = [[[4., 5., 9., 0.]...
out = [[ 9. 11. 19. 1.]...
Since moving_avg is 0 and moving_var is 1, the normalization step changes nothing, so the activation is only scaled/shifted using gamma and beta (see the quick check below).
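
A quick numpy check of that formula (assuming the layer's default epsilon of 1e-3):

import numpy as np

# Inference-mode batch norm: normalize with the stored moving statistics,
# then scale/shift. With moving_mean=0 and moving_var=1 this reduces to
# out = gamma * x + beta.
x = np.array([4., 5., 9., 0.])
gamma, beta, moving_mean, moving_var, eps = 2., 1., 0., 1., 1e-3
out = gamma * (x - moving_mean) / np.sqrt(moving_var + eps) + beta
print(np.round(out))   # [ 9. 11. 19.  1.]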

When training=True and trainable=False:

  out = [[ 2.  2.  3. -1.] ...
The activation is normalized using the mean and variance of the current batch, then scaled/shifted with `gamma` and `beta` (a rough sketch follows).
The moving averages are not updated.
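
A rough numpy sketch of what training mode does; the exact numbers depend on the random img, and per-channel statistics are taken over the batch and spatial axes:

import numpy as np

# Training-mode batch norm: normalize with the statistics of the current
# batch, then scale/shift with gamma and beta. The outputs end up roughly
# centered around beta = 1.
x = np.random.randint(0, 10, (1, 2, 2, 4)).astype(np.float32)
gamma, beta, eps = 2., 1., 1e-3
mean = x.mean(axis=(0, 1, 2))
var = x.var(axis=(0, 1, 2))
out = gamma * (x - mean) / np.sqrt(var + eps) + beta
print(np.round(out))   # e.g. [[[ 2.  2.  3. -1.] ...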

When training=True and trainable=True:

  The output is the same as above, but `moving_avg` and `moving_var` get updated to new values (the update rule is sketched after the printout):

moving_avg [0.03249997 0.03499997 0.06499994 0.02749997]
moving_variance [1.0791667 1.1266665 1.0999999 1.0925]
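Each time the update ops run, the moving statistics move a small step towards the current batch statistics; tf.layers.batch_normalization uses momentum=0.99 by default. A sketch of that update rule, with hypothetical batch statistics chosen only to roughly match the printout above:

import numpy as np

# moving_stat <- momentum * moving_stat + (1 - momentum) * batch_stat
momentum = 0.99
moving_mean = np.zeros(4, dtype=np.float32)     # initial value, all zeros
moving_var = np.ones(4, dtype=np.float32)       # initial value, all ones
batch_mean = np.array([3.25, 3.5, 6.5, 2.75])   # hypothetical batch statistics
batch_var = np.array([8.9, 12.7, 11.0, 10.3])   # hypothetical batch statistics

moving_mean = momentum * moving_mean + (1 - momentum) * batch_mean
moving_var = momentum * moving_var + (1 - momentum) * batch_var
print(moving_mean)   # ~[0.0325 0.035  0.065  0.0275]
print(moving_var)    # ~[1.079  1.117  1.1    1.093 ]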

Regarding the significance of the "trainable" and "training" flags in tf.layers.batch_normalization, we found a similar question on Stack Overflow: https://stackoverflow.com/questions/50209310/
