gpt4 book ai didi

python - 在同一 Tensorflow session 中从 Saver 加载两个模型

转载 作者:太空狗 更新时间:2023-10-29 22:03:26 27 4
gpt4 key购买 nike

我有两个网络:一个生成输出的 Model 和一个对输出评分的 Adversary

两者都是单独训练的,但现在我需要在一个 session 中合并它们的输出。

我已经尝试实现这篇文章中提出的解决方案:Run multiple pre-trained Tensorflow nets at the same time

我的代码

with tf.name_scope("model"):
model = Model(args)
with tf.name_scope("adv"):
adversary = Adversary(adv_args)

#...

with tf.Session() as sess:
tf.global_variables_initializer().run()

# Get the variables specific to the `Model`
# Also strip out the surperfluous ":0" for some reason not saved in the checkpoint
model_varlist = {v.name.lstrip("model/")[:-2]: v
for v in tf.global_variables() if v.name[:5] == "model"}
model_saver = tf.train.Saver(var_list=model_varlist)
model_ckpt = tf.train.get_checkpoint_state(args.save_dir)
model_saver.restore(sess, model_ckpt.model_checkpoint_path)

# Get the variables specific to the `Adversary`
adv_varlist = {v.name.lstrip("avd/")[:-2]: v
for v in tf.global_variables() if v.name[:3] == "adv"}
adv_saver = tf.train.Saver(var_list=adv_varlist)
adv_ckpt = tf.train.get_checkpoint_state(adv_args.save_dir)
adv_saver.restore(sess, adv_ckpt.model_checkpoint_path)

问题

对函数 model_saver.restore() 的调用似乎什么也没做。在另一个模块中,我使用了一个带有 tf.train.Saver(tf.global_variables()) 的保护程序,它可以很好地恢复检查点。

模型有 model.tvars = tf.trainable_variables()。为了检查发生了什么,我使用 sess.run() 在恢复前后提取了 tvars。每次使用初始随机分配的变量并且不分配来自检查点的变量。

关于为什么 model_saver.restore() 似乎什么都不做有什么想法吗?

最佳答案

解决这个问题花了很长时间,所以我发布了我可能不完美的解决方案,以防其他人需要它。

为了诊断问题,我手动遍历了每个变量并一一分配给它们。然后我注意到在分配变量后名称会改变。此处对此进行了描述:TensorFlow checkpoint save and read

根据那篇文章中的建议,我在各自的图表中运行了每个模型。这也意味着我必须在其自己的 session 中运行每个图表。这意味着以不同的方式处理 session 管理。

首先我创建了两个图表

model_graph = tf.Graph()
with model_graph.as_default():
model = Model(args)

adv_graph = tf.Graph()
with adv_graph.as_default():
adversary = Adversary(adv_args)

然后两个session

adv_sess = tf.Session(graph=adv_graph)
sess = tf.Session(graph=model_graph)

然后我在每个 session 中初始化变量并分别恢复每个图

with sess.as_default():
with model_graph.as_default():
tf.global_variables_initializer().run()
model_saver = tf.train.Saver(tf.global_variables())
model_ckpt = tf.train.get_checkpoint_state(args.save_dir)
model_saver.restore(sess, model_ckpt.model_checkpoint_path)

with adv_sess.as_default():
with adv_graph.as_default():
tf.global_variables_initializer().run()
adv_saver = tf.train.Saver(tf.global_variables())
adv_ckpt = tf.train.get_checkpoint_state(adv_args.save_dir)
adv_saver.restore(adv_sess, adv_ckpt.model_checkpoint_path)

从这里开始,每当需要每个 session 时,我都会在该 session 中使用 with sess.as_default(): 包装任何 tf 函数。最后我手动关闭 session

sess.close()
adv_sess.close()

关于python - 在同一 Tensorflow session 中从 Saver 加载两个模型,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/41607144/

27 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com