gpt4 book ai didi

tensorflow - 如何使用 tf.train.MonitoredTrainingSession 仅恢复某些变量

转载 作者:行者123 更新时间:2023-12-04 01:49:59 27 4
gpt4 key购买 nike

如何告诉 tf.train.MonitoredTrainingSession 仅恢复变量的一个子集,并对其余变量执行初始化?

从 cifar10 教程开始..
https://github.com/tensorflow/models/blob/master/tutorials/image/cifar10/cifar10_train.py

.. 我创建了要恢复和初始化的变量列表,并使用传递给 MonitoredTrainingSession 的 Scaffold 指定它们:

  restoration_saver = Saver(var_list=restore_vars)
restoration_scaffold = Scaffold(init_op=variables_initializer(init_vars),
ready_op=constant([]),
saver=restoration_saver)

但这给出了以下错误:

RuntimeError: Init operations did not make model ready for local_init. Init op: group_deps, init fn: None, error: Variables not initialized: conv2a/T, conv2b/T, [...]



.. 其中错误消息中列出的未初始化变量是我的“init_vars”列表中的变量。

SessionManager.prepare_session() 引发了异常。该方法的源代码似乎表明,如果 session 是从检查点恢复的,则不会运行 init_op。所以看起来你可以有恢复的变量或初始化的变量,但不能两者兼而有之。

最佳答案

好吧,正如我所怀疑的,通过基于现有的 tf.training.SessionManager 实现一个新的 RefinementSessionManager 类,我得到了我想要的东西。这两个类几乎相同,除了我修改了 prepare_session 方法以调用 init_op,无论模型是否从检查点加载。

这允许我从检查点加载变量列表并初始化 init_op 中的剩余变量。

我的 prepare_session 方法是这样的:

  def prepare_session(self, master, init_op=None, saver=None,
checkpoint_dir=None, wait_for_checkpoint=False,
max_wait_secs=7200, config=None, init_feed_dict=None,
init_fn=None):

sess, is_loaded_from_checkpoint = self._restore_checkpoint(
master,
saver,
checkpoint_dir=checkpoint_dir,
wait_for_checkpoint=wait_for_checkpoint,
max_wait_secs=max_wait_secs,
config=config)

# [removed] if not is_loaded_from_checkpoint:
# we still want to run any supplied initialization on models that
# were loaded from checkpoint.

if not is_loaded_from_checkpoint and init_op is None and not init_fn and self._local_init_op is None:
raise RuntimeError("Model is not initialized and no init_op or "
"init_fn or local_init_op was given")
if init_op is not None:
sess.run(init_op, feed_dict=init_feed_dict)
if init_fn:
init_fn(sess)

# [...]

希望这对其他人有帮助。

关于tensorflow - 如何使用 tf.train.MonitoredTrainingSession 仅恢复某些变量,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/43336553/

27 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com