gpt4 book ai didi

python - 如何在 TensorFlow 中使用指数移动平均

转载 作者:太空狗 更新时间:2023-10-30 00:34:12 25 4
gpt4 key购买 nike

问题

TensorFlow 提供了 `tf.train.ExponentialMovingAverage`,它使我们能够对模型参数应用指数移动平均,我发现这对于稳定模型的测试结果非常有用。

话虽如此,我发现将其应用于一般模型有点令人恼火。到目前为止,我最成功的方法(如下所示)是编写一个函数装饰器,然后将我的整个神经网络放入一个函数中。

但这也有一些缺点。首先,它复制了整个图,其次,我需要在一个函数中定义我的神经网络。

有更好的方法吗?

当前实现

def ema_wrapper(is_training, decay=0.99):
    """Use Exponential Moving Average (EMA) of parameters during testing.

    Decorator factory.  The decorated graph-building function is called
    twice inside the variable scope ``'ema_wrapper'``:

    * once normally, which creates the trainable variables;
    * once with ``reuse=True`` and a custom getter that substitutes each
      variable's EMA shadow variable.

    The EMA update op is added to ``tf.GraphKeys.UPDATE_OPS``.  Callers must
    run that collection for the averages to be updated, e.g. as a control
    dependency of the train op.

    Parameters
    ----------
    is_training : bool or `tf.Tensor` of type bool
        EMA is applied if ``is_training`` is False.  A Python bool selects
        the result statically.  A boolean tensor selects it at run time
        via `tf.cond`.
    decay : float
        Decay rate for `tf.train.ExponentialMovingAverage`.

    Returns
    -------
    callable
        A decorator that wraps a graph-building function.
    """
    def function(fun):
        @functools.wraps(fun)
        def fun_wrapper(*args, **kwargs):
            # Regular call: creates the trainable variables.
            with tf.variable_scope('ema_wrapper', reuse=False) as scope:
                result_train = fun(*args, **kwargs)

            # Set up exponential moving average over the variables created
            # above.  get_collection filters with a regex prefix match.  The
            # trailing '/' keeps it from also matching sibling scopes such
            # as 'ema_wrapper_1'.
            ema = tf.train.ExponentialMovingAverage(decay=decay)
            var_class = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                          scope.name + '/')
            ema_op = ema.apply(var_class)

            # Add to collection so the averages get updated.
            tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, ema_op)

            # Getter for the variables with EMA applied.
            def ema_getter(getter, name, *args, **kwargs):
                var = getter(name, *args, **kwargs)
                ema_var = ema.average(var)
                # ema.average returns None for variables it does not track.
                # Compare against None explicitly: truth-testing a variable
                # object is not a None check (resource variables raise when
                # truth-tested in graph mode).
                return var if ema_var is None else ema_var

            # Call again, reusing the same variables but reading their EMAs.
            with tf.variable_scope('ema_wrapper', reuse=True,
                                   custom_getter=ema_getter):
                result_test = fun(*args, **kwargs)

            # tf.cond rejects a Python bool predicate, so resolve a static
            # flag here.  Use tf.cond only for a run-time boolean tensor.
            if isinstance(is_training, bool):
                return result_train if is_training else result_test
            return tf.cond(is_training,
                           lambda: result_train, lambda: result_test)
        return fun_wrapper
    return function

示例用法:

@ema_wrapper(is_training)
def neural_network(x):
    """Multiply ``x`` by one trainable scalar named ``'a'``.

    When ``is_training`` is False, the decorator reads the exponential
    moving average of ``'a'`` instead of its raw value.
    """
    scale = tf.get_variable('a', shape=[], dtype=tf.float32)
    return scale * x

最佳答案

您可以创建一个操作,把 EMA(影子)变量中的值复制回原始变量:

import tensorflow as tf

# Make model...
minimize_op = ...
model_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
decay = 0.99  # EMA decay rate (placeholder; tune for your model)

# Make EMA object and update its internal variables after each optimization step
ema = tf.train.ExponentialMovingAverage(decay=decay)
with tf.control_dependencies([minimize_op]):
    train_op = ema.apply(model_vars)

# Transfer EMA values to original variables.
# Note: this overwrites the trained weights.  Keep a backup if you want to
# resume training from the non-averaged values.
retrieve_ema_weights_op = tf.group(
    *[tf.assign(var, ema.average(var)) for var in model_vars])

# Create the initializer after ema.apply so it also covers the EMA shadow
# variables.  Without it, running train_op fails on uninitialized variables.
init_op = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init_op)
    # Do training (`...` is a placeholder for your loop condition)
    while ...:
        sess.run(train_op, ...)
    # Copy EMA values to weights
    sess.run(retrieve_ema_weights_op)
    # Test model with EMA weights
    # ...

编辑:

我写了一个更完整的版本,它借助变量备份在训练模式和测试模式之间来回切换:

import tensorflow as tf

# Make model...
minimize_op = ...
model_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
decay = 0.99  # EMA decay rate (placeholder; tune for your model)

# Mode flag: True in training mode, False while the EMA weights are swapped in.
# It is bookkeeping, not a model parameter.  Keep it out of
# TRAINABLE_VARIABLES, since tf.get_variable defaults to trainable=True.
is_training = tf.get_variable('is_training', shape=(), dtype=tf.bool,
                              initializer=tf.constant_initializer(True, dtype=tf.bool),
                              trainable=False)

# Make EMA object and update internal variables after each optimization step
ema = tf.train.ExponentialMovingAverage(decay=decay)
with tf.control_dependencies([minimize_op]):
    train_op = ema.apply(model_vars)

# Make backup variables: one non-trainable copy per model variable.  They hold
# the training weights while the EMA weights are swapped in.
with tf.variable_scope('BackupVariables'):
    backup_vars = [tf.get_variable(var.op.name, dtype=var.value().dtype, trainable=False,
                                   initializer=var.initialized_value())
                   for var in model_vars]

def ema_to_weights():
    """Return an op that overwrites each model variable with its EMA value."""
    assign_ops = [tf.assign(model_var, ema.average(model_var).read_value())
                  for model_var in model_vars]
    return tf.group(*assign_ops)
def save_weight_backups():
    """Return an op that copies each model variable into its backup slot."""
    copy_ops = [tf.assign(backup, source.read_value())
                for source, backup in zip(model_vars, backup_vars)]
    return tf.group(*copy_ops)
def restore_weight_backups():
    """Return an op that copies each backup value back into its model variable."""
    restore_ops = [tf.assign(target, backup.read_value())
                   for target, backup in zip(model_vars, backup_vars)]
    return tf.group(*restore_ops)

def to_training():
    """Return an op that sets the training flag, then restores backed-up weights."""
    set_flag_op = tf.assign(is_training, True)
    with tf.control_dependencies([set_flag_op]):
        return restore_weight_backups()

def to_testing():
    """Return an op that enters test mode with EMA weights.

    The op clears the training flag, then backs up the current weights, then
    overwrites the weights with their EMA values.
    """
    clear_flag_op = tf.assign(is_training, False)
    with tf.control_dependencies([clear_flag_op]):
        backup_op = save_weight_backups()
        # The EMA copy must run only after the backup has been taken.
        with tf.control_dependencies([backup_op]):
            return ema_to_weights()

def _keep_current_mode():
    """Branch taken when the requested mode is already active: do nothing."""
    return tf.group()


# Each switch op does nothing when the graph is already in the requested
# mode, so running it repeatedly never overwrites the saved backups.
switch_to_train_mode_op = tf.cond(is_training, _keep_current_mode, to_training)
switch_to_test_mode_op = tf.cond(is_training, to_testing, _keep_current_mode)

# Created last so that it covers every variable defined above.
init_op = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init_op)
    # The flag starts as True (training), so this switch is a harmless no-op.
    sess.run(switch_to_train_mode_op)
    # Training loop (`...` is a placeholder for your loop condition)
    while ...:
        sess.run(train_op, ...)
    # Swap the EMA weights in for evaluation.
    sess.run(switch_to_test_mode_op)
    # A second switch takes the no-op branch and leaves the backups intact.
    sess.run(switch_to_test_mode_op)
    # Test model with EMA weights
    # ...
    # Restore the backed-up training weights.
    sess.run(switch_to_train_mode_op)
    # Keep training...

关于python - 如何在 TensorFlow 中使用指数移动平均,我们在 Stack Overflow 上找到一个类似的问题: https://stackoverflow.com/questions/49147961/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com