gpt4 book ai didi

python - tensorflow 在图执行期间添加摘要

转载 作者:行者123 更新时间:2023-11-30 09:19:56 25 4
gpt4 key购买 nike

我想在tensorflow教程给出的cifar10示例中输出每个固定步数的准确性,我尝试使用tf.summary.scalar(..) 在产生错误的钩子(Hook)中:图形已完成。但是,我认为我只能访问钩子(Hook)中的步骤数(我正在使用 cifar10_eval.py 评估准确性,这也是 tensorflow 教程给出的示例代码)。我还尝试将 global_step 写入检查点,但不幸的是 MointeredTrainingSession 仅支持时间间隔(save_checkpoint_secs)而不是步间隔。有什么建议吗?

cifar10_train.py

def train():
"""Train CIFAR-10 for a number of steps."""
with tf.Graph().as_default():
global_step = tf.contrib.framework.get_or_create_global_step()

# Build a Graph that trains the model with one batch of examples and
# updates the model parameters.
train_op = cifar10.train(loss, global_step)

class _LoggerHook(tf.train.SessionRunHook):
"""Logs loss and runtime."""

def begin(self):
self._step = -1
self._start_time = time.time()

def before_run(self, run_context):
self._step += 1
return tf.train.SessionRunArgs(loss) # Asks for loss value.

def after_run(self, run_context, run_values):
<output some information>

with tf.train.MonitoredTrainingSession(
checkpoint_dir=FLAGS.train_dir,
hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
tf.train.NanTensorHook(loss),
_LoggerHook()],
config=tf.ConfigProto(
log_device_placement=FLAGS.log_device_placement)) as mon_sess:
while not mon_sess.should_stop():
mon_sess.run(train_op)

最佳答案

首先,应该指出的是,Tensorflow 提供的 cifar10 教程在两个单独的 session 中运行训练和评估。当训练 session 保存检查点时,评估 session 将检索该检查点。然后加载参数并进行评估。您在此处粘贴的代码仅用于培训类(class)。

我的建议是,你应该明确你要写哪个摘要。因为培训和评估是两个不同的阶段。摘要作者有两位。通常,他们会为不同的摘要作者提供不同的路径。

根据您的需要,这里有一些针对您的项目的提示。

  • 您不应向检查点写入任何内容,其中包含大量模型参数。
  • 请使用摘要编写器或标准 I/O 来保证记录的准确性。
  • 尝试使用摘要编写器时出现错误,因为应在启动 session 之前添加摘要的所有元素(包括标量)。

我猜 Tensorflow 将摘要视为默认图的一部分。因此,您可能需要在运行 session 之前配置摘要编写器。

关于python - tensorflow 在图执行期间添加摘要,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/43671435/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com