gpt4 book ai didi

tensorflow - 如何仅在 Tensorflow 的检查点中恢复变量?

转载 作者:行者123 更新时间:2023-12-04 02:17:57 26 4
gpt4 key购买 nike

在 Tensorflow 中,我的模型基于预训练模型,我添加了更多变量并删除了预训练模型中的一些变量。当我从检查点文件恢复变量时,我必须明确指定我添加到图中需要排除的所有变量。例如,我做了

exclude = # explicitly list all variables to exclude
variables_to_restore = slim.get_variables_to_restore(exclude=exclude)
saver = tf.train.Saver(variables_to_restore)

有没有更简单的方法来做到这一点?即,只要变量不在检查点中,就不要尝试恢复。

最佳答案

您应该首先找出所有有用的变量(也意味着在您的图中),然后从检查点而不是从检查点添加两者的交集的联合集。

variables_can_be_restored = list(set(tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES)).intersection(tf.train.list_variables(checkpoint_dir))) 

然后在定义这样的保护程序后恢复它:
temp_saver = tf.train.Saver(variables_can_be_restored)
ckpt_state = tf.train.get_checkpoint_state(checkpoint_dir, lastest_filename)
print('Loading checkpoint %s' % ckpt_state.model_checkpoint_path)
temp_saver.restore(sess, ckpt_state.model_checkpoint_path)

关于tensorflow - 如何仅在 Tensorflow 的检查点中恢复变量?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/47731175/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com