gpt4 book ai didi

python - 如何不在 Tensorflow 中重新初始化预训练加载的模型?

转载 作者:太空宇宙 更新时间:2023-11-03 14:40:43 25 4
gpt4 key购买 nike

我已使用以下代码加载了预训练模型(模型 1):

def load_seq2seq_model(sess):


with open(os.path.join(seq2seq_config_dir_path, 'config.pkl'), 'rb') as f:
saved_args = pickle.load(f)

# Initialize the model with saved args
model = Model1(saved_args)

#Inititalize Tensorflow saver
saver = tf.train.Saver()

# Checkpoint
ckpt = tf.train.get_checkpoint_state(seq2seq_config_dir_path)
print('Loading model: ', ckpt.model_checkpoint_path)

# Restore the model at the checkpoint
saver.restore(sess, ckpt.model_checkpoint_path)
return model

现在,我想从头开始训练另一个模型(模型 2),它将采用模型 1 的输出。但为此,我需要定义一个 session 并加载预训练的模型并初始化模型tf.initialize_all_variables()。因此,预训练模型也将被初始化。

谁能告诉我如何在正确获取预训练模型 Model 1 的输出后训练 Model 2 吗?

我正在尝试的内容如下 -

with tf.Session() as sess:
# Initialize all the variables of the graph
seq2seq_model = load_seq2seq_model(sess)
sess.run(tf.initialize_all_variables())
.... Rest of the training code goes here....

最佳答案

使用保护程序恢复的所有变量都不需要初始化。因此,您可以使用 tf.variables_initializer(var_list) 来仅初始化第二个网络的权重,而不是使用 tf.initialize_all_variables() 。

要获取第二个网络的所有权重列表,您可以在变量范围中创建 Model 2 网络:

with tf.variable_scope("model2"):
model2 = Model2(...)

然后使用

model_2_variables_list = tf.get_collection(
tf.GraphKeys.GLOBAL_VARIABLES,
scope="model2"
)

获取Model 2网络的变量列表。最后,您可以为第二个网络创建初始化程序:

init2 = tf.variables_initializer(model_2_variables_list)

with tf.Session() as sess:
# Initialize all the variables of the graph
seq2seq_model = load_seq2seq_model(sess)
sess.run(init2)
.... Rest of the training code goes here....

关于python - 如何不在 Tensorflow 中重新初始化预训练加载的模型?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/46564427/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com