
python - MonitoredTrainingSession writes multiple metagraph events per run


When writing checkpoint files with tf.train.MonitoredTrainingSession, it somehow writes multiple metagraphs. What am I doing wrong?

I stripped my code down to the following:

import tensorflow as tf

output_path = "../train/"  # assumed; matches the --logdir used below

global_step = tf.Variable(0, dtype=tf.int32, trainable=False, name="global_step")
train = tf.assign(global_step, global_step + 1)
saver = tf.train.Saver()
hooks = [tf.train.CheckpointSaverHook(checkpoint_dir=output_path + "test1/ckpt/",
                                      save_steps=10,
                                      saver=saver)]

with tf.train.MonitoredTrainingSession(master='',
                                       is_chief=True,
                                       checkpoint_dir=None,
                                       hooks=hooks,
                                       save_checkpoint_secs=None,
                                       save_summaries_steps=None,
                                       save_summaries_secs=None) as mon_sess:
    for i in range(30):
        if mon_sess.should_stop():
            break
        try:
            gs, _ = mon_sess.run([global_step, train])
            print(gs)
        except (tf.errors.OutOfRangeError, tf.errors.CancelledError):
            break
        finally:
            pass

Running this produces duplicate metagraphs, as evidenced by the TensorBoard warning:

$ tensorboard --logdir ../train/test1/ --port=6006

WARNING:tensorflow:Found more than one graph event per run, or there was a metagraph containing a graph_def, as well as one or more graph events. Overwriting the graph with the newest event.
Starting TensorBoard 54 at local:6006 (Press CTRL+C to quit)

This is on TensorFlow 1.2.0 (and I cannot upgrade).

Running the same thing without a monitored session gives the correct checkpoint output:

import tensorflow as tf

output_path = "../train/"  # assumed, as above

global_step = tf.Variable(0, dtype=tf.int32, trainable=False, name="global_step")
train = tf.assign(global_step, global_step + 1)
saver = tf.train.Saver()
init_op = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init_op)
    for i in range(30):
        gs, _ = sess.run([global_step, train])
        print(gs)
        if i % 10 == 0:
            saver.save(sess, output_path + 'test2/my-model', global_step=gs)
            print("Saved ckpt")

This gives results without the TensorBoard error:

$ tensorboard --logdir ../train/test2/ --port=6006

Starting TensorBoard 54 at local:6006 (Press CTRL+C to quit)

I'd like to fix this, as I suspect I'm missing something fundamental, and this error may have some connection to other problems I'm having in distributed mode. Whenever I want to update the data, I have to restart TensorBoard. Moreover, TensorBoard seems to become very slow over time when it emits many of these warnings.

There is a related question: tensorflow Found more than one graph event per run. In that case, the error was due to multiple runs (with different parameters) writing to the same output directory. The case here is about a single run to a clean output directory.

Running the MonitoredTrainingSession version in distributed mode gives the same errors.

Update Oct-12

@Nikhil Kothari suggested using tf.train.MonitoredSession instead of the larger tf.train.MonitoredTrainingSession wrapper, as follows:

import tensorflow as tf

output_path = "../train/"  # assumed, as above

global_step = tf.Variable(0, dtype=tf.int32, trainable=False, name="global_step")
train = tf.assign(global_step, global_step + 1)
saver = tf.train.Saver()
hooks = [tf.train.CheckpointSaverHook(checkpoint_dir=output_path + "test3/ckpt/",
                                      save_steps=10,
                                      saver=saver)]

chiefsession = tf.train.ChiefSessionCreator(scaffold=None,
                                            master='',
                                            config=None,
                                            checkpoint_dir=None,
                                            checkpoint_filename_with_path=None)
with tf.train.MonitoredSession(session_creator=chiefsession,
                               hooks=hooks,
                               stop_grace_period_secs=120) as mon_sess:
    for i in range(30):
        if mon_sess.should_stop():
            break
        try:
            gs, _ = mon_sess.run([global_step, train])
            print(gs)
        except (tf.errors.OutOfRangeError, tf.errors.CancelledError):
            break
        finally:
            pass

Unfortunately, this still gives the same TensorBoard error:

$ tensorboard --logdir ../train/test3/ --port=6006

WARNING:tensorflow:Found more than one graph event per run, or there was a metagraph containing a graph_def, as well as one or more graph events. Overwriting the graph with the newest event.
Starting TensorBoard 54 at local:6006 (Press CTRL+C to quit)

BTW, each code block is self-contained: copy-paste it into a Jupyter notebook and you will reproduce the problem.

Best Answer

I wonder if this is because every node in your cluster is running the same code, declaring itself as the chief, and saving out graphs and checkpoints.

I don't know whether the is_chief = True is just illustrative in your post here or exactly what you are using... so a bit of a guess here (a sketch of deriving is_chief from the task's role follows).
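
As an illustration of that guess, here is a minimal sketch of deriving is_chief from the task's role rather than hardcoding it, using the common tf.train flags pattern (the flag names and defaults are assumptions, not taken from the question):

import tensorflow as tf

flags = tf.app.flags
flags.DEFINE_string("job_name", "worker", "One of 'ps' or 'worker'")
flags.DEFINE_integer("task_index", 0, "Index of this task within its job")
FLAGS = flags.FLAGS

# Only the first worker declares itself chief; all other tasks pass
# is_chief=False, so a single node owns graph and checkpoint writing.
is_chief = (FLAGS.job_name == "worker" and FLAGS.task_index == 0)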

I personally used MonitoredSession instead of MonitoredTrainingSession and created the list of hooks based on whether the code is running on the master/chief or not; a sketch of that pattern follows. Example: https://github.com/TensorLab/tensorfx/blob/master/src/training/_trainer.py#L94
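
Here is a hedged sketch of that pattern, adapted to the question's code rather than copied from the link above (is_chief, output_path, and the step counts are assumptions):

import tensorflow as tf

is_chief = True            # in a real cluster, derived from the task's role
output_path = "../train/"  # assumed

global_step = tf.Variable(0, dtype=tf.int32, trainable=False, name="global_step")
train = tf.assign(global_step, global_step + 1)

# Every task stops at the same step, but only the chief gets the saving hook,
# so only one node ever writes graph events and checkpoints.
hooks = [tf.train.StopAtStepHook(last_step=30)]
if is_chief:
    hooks.append(tf.train.CheckpointSaverHook(checkpoint_dir=output_path + "ckpt/",
                                              save_steps=10,
                                              saver=tf.train.Saver()))

session_creator = (tf.train.ChiefSessionCreator(master='') if is_chief
                   else tf.train.WorkerSessionCreator(master=''))
with tf.train.MonitoredSession(session_creator=session_creator,
                               hooks=hooks) as mon_sess:
    while not mon_sess.should_stop():
        gs, _ = mon_sess.run([global_step, train])
        print(gs)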

Regarding python - MonitoredTrainingSession writes multiple metagraph events per run, see the similar question on Stack Overflow: https://stackoverflow.com/questions/46636558/
