gpt4 book ai didi

tensorflow - 如何从 tensorflow 中保存的检查点恢复特定范围的变量?

转载 作者:行者123 更新时间:2023-12-03 12:43:54 25 4
gpt4 key购买 nike

import tensorflow as tf
saver = tf.train.Saver()
saver.restore(...)

但是 saver.restore 只有恢复整个图的选项。我只想恢复特定范围内的那些变量。

提前致谢!

最佳答案

假设您在 InceptionV1 范围内拥有 Google 的 InceptionNet 模型,并且您想要加载它,但要重新训练范围 InceptionRetrained 中的最后一层除外。

假设您已经开始重新训练最后一层,并且您通过 saver2.save(session, 'last_layer.ckpt') 创建了 last_layer.ckpt 文件,下面是如何从两个检查点恢复网络。

saver1 = tf.train.Saver(var_list=tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='InceptionV1'))
saver1.restore(session, 'inception_model_from_google.ckpt')

saver2 = tf.train.Saver(var_list=tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='InceptionRetrained'))
saver2.restore(session, 'last_layer.ckpt')

如果您只重新训练最后一层,请不要忘记通过使用 var_list 参数调用优化器来禁用梯度在网络上的传播(节省时间)。

tf.train.Optimizer(0.0001).minimize(
loss, var_list=tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='Inceptionretrained'))

关于tensorflow - 如何从 tensorflow 中保存的检查点恢复特定范围的变量?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/42546365/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com