gpt4 book ai didi

python - 如何恢复预训练模型以初始化参数

转载 作者:行者123 更新时间:2023-11-28 19:01:04 24 4
gpt4 key购买 nike

我已经下载了一个带有预训练模型的网络。我在网络中添加了几个层和参数,我想使用这个预训练模型来初始化原始参数,并自己随机初始化新添加的参数。我使用此代码:

saver = tf.train.Saver()
with tf.Session() as sess:
saver.restore(sess, "output/saver-test")
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())

但是我遇到了错误:“Key global_step not found in checkpoint”,这个错误是因为我有一些在预训练模型中不存在的新参数。但是我该如何解决这个问题呢?更何况,我想使用这段代码“sess.run(tf.global_variables_initializer())”来初始化新添加的参数,但是从预训练模型中提取的参数会被它覆盖吗?

最佳答案

发生这种情况是因为您的网络与加载的网络不完全匹配。您可以使用类似的选择性检查点加载程序:

  reader = tf.train.NewCheckpointReader(os.path.join(checkpoint_dir, ckpt_name))
restore_dict = dict()
for v in tf.trainable_variables():
tensor_name = v.name.split(':')[0]
if reader.has_tensor(tensor_name):
print('has tensor ', tensor_name)
restore_dict[tensor_name] = v

restore_dict['my_new_var_scope/my_new_var'] = self.get_my_new_var_variable()

其中 get_my_new_var_variable() 是这样的:

    def get_my_new_var_variable(self):
with tf.variable_scope("my_new_var_scope",reuse=tf.AUTO_REUSE):
my_new_var = tf.get_variable("my_new_var", dtype=tf.int32,initializer=tf.constant([23, 42]))
return my_new_var

加载权重:

self.saver = tf.train.Saver(restore_dict)
self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))

编辑:

请注意,为了避免覆盖加载的变量,您可以使用此方法:

def initialize_uninitialized(sess):
global_vars = tf.global_variables()
is_not_initialized = sess.run([tf.is_variable_initialized(var) for var in global_vars])
not_initialized_vars = [v for (v, f) in zip(global_vars, is_not_initialized) if not f]
if len(not_initialized_vars):
sess.run(tf.variables_initializer(not_initialized_vars))

或者在加载变量之前简单地调用 tf.global_variables_initializer() 应该在这里工作。

关于python - 如何恢复预训练模型以初始化参数,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/52532150/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com