gpt4 book ai didi

python - TensorFlow 'global_step' 变量未针对指数衰减进行更新

转载 作者:行者123 更新时间:2023-12-01 13:36:57 25 4
gpt4 key购买 nike

我有一个网络,我在其中使用学习率的指数衰减。为此,我正在跟踪一个“global_step”TF 变量,该变量在处理的每个批处理中递增 1。然而,看起来实际上,它并没有真正得到更新。这是代码。

...
global_step = tf.Variable(0, trainable=False, name='global_step')
starter_learning_rate = 0.01
learning_rate = tf.train.exponential_decay(starter_learning_rate, global_step, 1000, 0.50)

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
optm = tf.train.AdamOptimizer(learning_rate).minimize(cost, global_step=global_step)
init = tf.global_variables_initializer()


def train(file):
global global_step
for batch in batches:
global_step += 1
...
return loss

...
global_step = 0
for epoch in EPOCHS:
for f in files:
loss = train(f)

函数内部和外部的 global_step 正在更新。但是我的学习率没有改变。当我将摘要附加到我的 TF global_step 变量时,我看到它始终保持为 0。

这里有什么问题?

最佳答案

实际上,我还没有看到你在哪里设置learning_rate变量,但这是如何使用它的方式:

定义全局步骤变量

global_step = tf.Variable(0)

定义学习率随不同params变化的方式

learning_rate = tf.train.exponential_decay(0.1, global_step, 500, 0.7, staircase=True)

将它们传递给优化器

optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)

关于python - TensorFlow 'global_step' 变量未针对指数衰减进行更新,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/42755157/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com