gpt4 book ai didi

python - 受监控的培训类(class)如何运作?

转载 作者:IT老高 更新时间:2023-10-28 22:00:57 24 4
gpt4 key购买 nike

我试图了解使用 tf.Sessiontf.train.MonitoredTrainingSession 之间的区别,以及我可能更喜欢其中一个。似乎当我使用后者时,我可以避免许多“琐事”,例如初始化变量、启动队列运行器或设置文件编写器以进行汇总操作。另一方面,在受监控的培训类(class)中,我无法明确指定要使用的计算图。这一切对我来说似乎相当神秘。这些类的创建方式背后是否有一些我不理解的基本哲学?

最佳答案

我无法就这些类的创建方式提供一些见解,但这里有一些我认为与您如何使用它们相关的事情。

tf.Session 是 python TensorFlow API 中的一个低级对象,而,正如您所说,tf.train.MonitoredTrainingSession 具有许多方便的功能,在大多数常见情况下尤其有用。

在描述 tf.train.MonitoredTrainingSession 的一些好处之前,让我回答一下关于 session 使用的图表的问题。您可以通过使用上下文管理器 with your_graph.as_default() 来指定 MonitoredTrainingSession 使用的 tf.Graph:

from __future__ import print_function
import tensorflow as tf

def example():
g1 = tf.Graph()
with g1.as_default():
# Define operations and tensors in `g`.
c1 = tf.constant(42)
assert c1.graph is g1

g2 = tf.Graph()
with g2.as_default():
# Define operations and tensors in `g`.
c2 = tf.constant(3.14)
assert c2.graph is g2

# MonitoredTrainingSession example
with g1.as_default():
with tf.train.MonitoredTrainingSession() as sess:
print(c1.eval(session=sess))
# Next line raises
# ValueError: Cannot use the given session to evaluate tensor:
# the tensor's graph is different from the session's graph.
try:
print(c2.eval(session=sess))
except ValueError as e:
print(e)

# Session example
with tf.Session(graph=g2) as sess:
print(c2.eval(session=sess))
# Next line raises
# ValueError: Cannot use the given session to evaluate tensor:
# the tensor's graph is different from the session's graph.
try:
print(c1.eval(session=sess))
except ValueError as e:
print(e)

if __name__ == '__main__':
example()

所以,正如你所说,使用 MonitoredTrainingSession 的好处是,这个对象负责

  • 初始化变量,
  • 启动队列运行器以及
  • 设置文件编写器,

但它还具有使您的代码易于分发的好处,因为它的工作方式也不同,具体取决于您是否将正在运行的进程指定为主进程。

例如,您可以运行类似:

def run_my_model(train_op, session_args):
with tf.train.MonitoredTrainingSession(**session_args) as sess:
sess.run(train_op)

您将以非分布式方式调用:

run_my_model(train_op, {})`

或以分布式方式(有关输入的更多信息,请参阅 distributed doc):

run_my_model(train_op, {"master": server.target,
"is_chief": (FLAGS.task_index == 0)})

另一方面,使用原始 tf.Session 对象的好处是,您没有 tf.train.MonitoredTrainingSession 的额外好处,如果您不打算使用它们或想要获得更多控制权(例如队列的启动方式),这可能会很有用。

编辑(根据评论):对于操作初始化,您必须执行类似 (cf. official doc :

# Define your graph and your ops
init_op = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init_p)
sess.run(your_graph_ops,...)

对于 QueueRunner,我建议您引用 official doc您可以在其中找到更完整的示例。

EDIT2:

了解 tf.train.MonitoredTrainingSession 工作原理的主要概念是 _WrappedSession 类:

This wrapper is used as a base class for various session wrappers that provide additional functionality such as monitoring, coordination, and recovery.

tf.train.MonitoredTrainingSession 以这种方式工作(从 version 1.1 开始):

  • 它首先检查它是主管还是 worker (参见 distributed doc 的词汇问题)。
  • 它开始已经提供的钩子(Hook)(例如,StopAtStepHook 在这个阶段只会检索 global_step 张量。
  • 它创建一个 session ,该 session 是一个 Chief(或 Worker session ),该 session 被包装在一个 _HookedSession 中,该 _CoordinatedSession 包装在一个 _CoordinatedSession 包装成 _RecoverableSession
    Chief/Worker session 负责运行 Scaffold 提供的初始化操作。
      scaffold: A `Scaffold` used for gathering or building supportive ops. If
    not specified a default one is created. It's used to finalize the graph.
  • chief session 还负责所有检查点部分:例如使用 Scaffold 中的 Saver 从检查点恢复。
  • _HookedSession 基本上是用来装饰 run 方法的:它调用 _call_hook_before_runafter_run 方法时相关的。
  • 在创建时,_CoordinatedSession 会构建一个 Coordinator,它会启动队列运行器并负责关闭它们。
  • _RecoverableSession 将确保在 tf.errors.AbortedError 的情况下重试。

总之,tf.train.MonitoredTrainingSession 避免了很多样板代码,同时可以通过钩子(Hook)机制轻松扩展。

关于python - 受监控的培训类(class)如何运作?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/43245231/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com