gpt4 book ai didi

python - Tensorflow:从先前创建的模型中保存变量子集

转载 作者:行者123 更新时间:2023-11-28 19:08:47 27 4
gpt4 key购买 nike

我创建了一个包含大量变量的模型(模型 A)。我计划使用模型 A 中的一些层在新模型(模型 B)上使用模型 A 进行迁移学习。但是,模型 B 具有与模型 A 相同的架构,因此我不能只加载模型 A 之前的所有变量运行模型 B,否则会出现命名等错误。因此,我正在尝试创建一个新的 ckpt 文件,它只存储我想要的模型 A 的权重。然后我将使用这个新的 ckpt 文件加载到模型 B . 我有以下内容:

sess = tf.Session()
saver = tf.train.import_meta_graph('ModelA.ckpt.meta')
saver.restore(sess, 'ModelA.ckpt')

# I did not explicity name my variables in model A so I am just placing them in the list and taking the ones I want

store_list = []
for v in tf.trainable_variables():
store_list.append(v)

var_list={"W_1": store_list[0], "b_1": store_list[1]}
v2_saver=tf.train.Saver(var_list)
sess.run(tf.global_variables_initializer())
v2_saver.save(sess, 'model_A_subset.ckpt')

但是,当我恢复 model_A_subset.ckpt 时,我仍然拥有 ModelA.ckpt 中的所有变量。难道我做错了什么?有没有一种方法可以轻松地从 ModelA.ckpt 中删除我不需要的变量并使用它?

最佳答案

您确定不必要的变量在检查点中吗?我问是因为在恢复检查点之前你需要创建图表,如果你正在创建一个包含 A 的所有变量的图表,你就会遇到这个问题。

要检查检查点并查看实际情况,请尝试 inspect checkpoint tool .

关于python - Tensorflow:从先前创建的模型中保存变量子集,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/43102199/

27 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com