
TensorFlow: Batch Norm breaks the network when is_training = False


I am trying to use the batch normalization layer from TensorFlow-Slim, like this:

net = ...
net = slim.batch_norm(net, scale=True, is_training=self.isTraining,
                      updates_collections=None, decay=0.9)
net = tf.nn.relu(net)
net = ...
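
For context, the snippets here assume self.isTraining is a boolean placeholder fed at run time; its definition is not shown in the question, so this is a sketch:

# Assumed setup (hypothetical, not from the question): a bool placeholder
# so the same graph serves both training and inference.
self.isTraining = tf.placeholder(tf.bool, name='isTraining')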

I train with:

self.optimizer = slim.learning.create_train_op(self.model.loss,
    tf.train.MomentumOptimizer(learning_rate=self.learningRate,
                               momentum=0.9, use_nesterov=True))

optimizer = self.sess.run([self.optimizer],
                          feed_dict={self.model.isTraining: True})

I load the saved weights with:

net = model.Model(sess, width, height, channels, weightDecay)

savedWeightsDir = './savedWeights/'
saver = tf.train.Saver(max_to_keep=5)
checkpointStr = tf.train.latest_checkpoint(savedWeightsDir)
# Initialize everything first; the restore then overwrites the
# initialized values with the checkpointed ones.
sess.run(tf.global_variables_initializer())
saver.restore(sess, checkpointStr)
global_step = tf.contrib.framework.get_or_create_global_step()

And I run inference with:

inf = self.sess.run([self.softmax],
                    feed_dict={self.imageBatch: imageBatch, self.isTraining: False})

Of course I have left a lot out and only sketched some of the code, but I believe that is everything batch norm touches. The strange thing is that I get better results if I set isTraining:True at inference time. Could the weight loading be the cause? Perhaps the batch norm statistics are not being saved. Is there an obvious error in the code? Thanks.
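
One quick way to test the "not saved" theory is to list the variables actually stored in the checkpoint. A minimal sketch, assuming the savedWeightsDir layout from the loading code above:

import tensorflow as tf

# List any batch norm moving statistics present in the latest checkpoint.
reader = tf.train.NewCheckpointReader(
    tf.train.latest_checkpoint('./savedWeights/'))
for name, shape in reader.get_variable_to_shape_map().items():
    if 'moving_mean' in name or 'moving_variance' in name:
        print(name, shape)

If nothing is printed, the moving statistics never made it into the checkpoint.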

Best Answer

I just ran into the same problem and found the solution here. The issue stems from the tf.layers.batch_normalization layer, which needs to update its moving_mean and moving_variance during training.

To do this correctly in your case, you need to modify your training code to:

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    self.optimizer = slim.learning.create_train_op(self.model.loss,
        tf.train.MomentumOptimizer(learning_rate=self.learningRate,
                                   momentum=0.9, use_nesterov=True))
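
As an aside, slim.learning.create_train_op also accepts an update_ops argument and, when it is left as None, defaults to the contents of the tf.GraphKeys.UPDATE_OPS collection, so the wrapper above mainly makes the dependency explicit.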

Or, more generally, from the documentation:

x_norm = tf.layers.batch_normalization(x, training=training)

# ...

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = optimizer.minimize(loss)
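
To see the effect in isolation, here is a small self-contained sketch (toy data and names, not from the question) showing that with the control_dependencies wrapper the moving statistics converge to the data statistics; without it they would stay at their initial 0/1 values, which is exactly what breaks inference:

import numpy as np
import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 4])
training = tf.placeholder(tf.bool)

# momentum=0.9 mirrors the decay=0.9 used in the question.
h = tf.layers.batch_normalization(x, training=training, momentum=0.9)
loss = tf.reduce_mean(tf.square(h))

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Toy data with mean ~5 and variance ~9.
    data = np.random.randn(256, 4).astype(np.float32) * 3.0 + 5.0
    for _ in range(100):
        sess.run(train_op, feed_dict={x: data, training: True})
    # moving_mean should now be close to 5 and moving_variance close
    # to 9; without the wrapper they would still be 0 and 1.
    moving = [v for v in tf.global_variables()
              if 'moving_mean' in v.name or 'moving_variance' in v.name]
    print(sess.run(moving))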

Regarding TensorFlow: Batch Norm breaks the network when is_training = False, a similar question can be found on Stack Overflow: https://stackoverflow.com/questions/44211371/
