gpt4 book ai didi

TensorFlow:从多个检查点恢复变量

转载 作者:行者123 更新时间:2023-12-03 00:45:34 27 4
gpt4 key购买 nike

我有以下情况:

  • 我有用 2 个单独的脚本编写的 2 个模型:

  • 模型 A 由变量 a1a2a3 组成,并用 A.py< 编写

  • 模型 B 由变量 b1b2b3 组成,并用 B.py 编写

A.pyB.py 中,我都有一个 tf.train.Saver 来保存所有局部变量,我们分别调用检查点文件 ckptAckptB

我现在想要制作一个使用 a1b1 的模型 C。我可以通过使用 var_scope 使 a1 的变量名在 A 和 C 中使用完全相同的变量名(对于 b1 也是如此)。

问题是如何将 a1b1ckptAckptB 加载到模型 C 中?例如,以下内容会起作用吗?

saver.restore(session, ckptA_location)
saver.restore(session, ckptB_location)

如果您尝试两次恢复同一 session ,是否会引发错误?它会提示没有为额外变量(b2b3a2a3 分配“槽”吗? >),或者它会简单地恢复它可以的变量,并且仅在 C 中存在未初始化的其他变量时才提示?

我现在正在尝试编写一些代码来测试这个问题,但我希望看到解决这个问题的规范方法,因为在尝试重新使用一些预先训练的权重时经常会遇到这种情况。

谢谢!

最佳答案

如果您尝试使用保护程序(默认情况下代表所有六个变量)从不包含保护程序所包含的所有变量的检查点进行恢复,您将收到 tf.errors.NotFoundError代表。 (但请注意,只要所有请求的变量都存在于相应的文件中,您就可以在同一 session 中针对变量的任何子集多次调用 Saver.restore() 。 )

规范的方法是定义两个单独的 tf.train.Saver实例覆盖完全包含在单个检查点中的每个变量子集。例如:

saver_a = tf.train.Saver([a1])
saver_b = tf.train.Saver([b1])

saver_a.restore(session, ckptA_location)
saver_b.restore(session, ckptB_location)

根据代码的构建方式,如果您在本地范围内有指向名为 a1b1tf.Variable 对象的指针,你可以在这里停止阅读。

另一方面,如果变量 a1b1 是在单独的文件中定义的,您可能需要采取一些创造性的措施来检索指向这些变量的指针。尽管这并不理想,但人们通常会使用通用前缀,例如如下(假设变量名称为 "a1:0""b1:0" 分别):

saver_a = tf.train.Saver([v for v in tf.all_variables() if v.name == "a1:0"])
saver_b = tf.train.Saver([v for v in tf.all_variables() if v.name == "b1:0"])

最后一点:您不必付出巨大的努力来确保 A 和 C 中的变量具有相同的名称。您可以将 name-to-Variable 字典作为第一个传递tf.train.Saver 构造函数的参数,从而将检查点文件中的名称重新映射到代码中的 Variable 对象。如果 A.pyB.py 具有类似名称的变量,或者如果您想在 C.py 中组织模型代码,这会有所帮助来自 tf.name_scope() 中的这些文件.

关于TensorFlow:从多个检查点恢复变量,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/35733917/

27 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com