gpt4 book ai didi

tensorflow - 在 Tensorflow 中恢复检查点时如何获取 global_step?

转载 作者:行者123 更新时间:2023-12-03 13:23:52 24 4
gpt4 key购买 nike

我正在保存我的 session 状态,如下所示:

self._saver = tf.saver()
self._saver.save(self._session, '/network', global_step=self._time)

当我稍后恢复时,我想获取我从中恢复的检查点的 global_step 值。这是为了从中设置一些超参数。

执行此操作的 hacky 方法是运行并解析检查点目录中的文件名。但是必须有一个更好的、内置的方法来做到这一点?

最佳答案

这有点骇人听闻,但其他答案对我根本不起作用

ckpt = tf.train.get_checkpoint_state(checkpoint_dir) 

#Extract from checkpoint filename
step = int(os.path.basename(ckpt.model_checkpoint_path).split('-')[1])

2017 年 9 月更新

我不确定这是否由于更新而开始工作,但以下方法似乎可以有效地让 global_step 正确更新和加载:

创建两个操作。一个用来保存 global_step,另一个用来增加它:
    global_step = tf.Variable(0, trainable=False, name='global_step')
increment_global_step = tf.assign_add(global_step,1,
name = 'increment_global_step')

现在在您的训练循环中,每次运行训练操作时都运行增量操作。
sess.run([train_op,increment_global_step],feed_dict=feed_dict)

如果您想在任何时候将全局步长值作为整数检索,只需在加载模型后使用以下命令:
sess.run(global_step)

这对于创建文件名或计算您当前的时期很有用,而无需第二个 tensorflow 变量来保存该值。例如,在加载时计算当前纪元将类似于:
loaded_epoch = sess.run(global_step)//(batch_size*num_train_records)

关于tensorflow - 在 Tensorflow 中恢复检查点时如何获取 global_step?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/36113090/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com