gpt4 book ai didi

python - Tensorflow:如何保存变量并将它们加载到不同的变量?

转载 作者:行者123 更新时间:2023-12-01 08:50:59 24 4
gpt4 key购买 nike

假设我有两个相同的网络:AB。我保存了(使用 Saver)网络 A 的先前状态,现在我想将其加载到网络 B 中(所有操作都发生在相同的运行)。我怎样才能做到这一点?

最佳答案

让我举个例子。首先,让我们定义并保存一些变量:

import tensorflow as tf

v1 = tf.Variable(tf.ones(1), name='v1')
v2 = tf.Variable(2 * tf.ones(1), name='v2')

saver = tf.train.Saver(tf.trainable_variables())
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
saver.save(sess, './tmp.ckpt')

现在,让我们在新图中定义一些具有相同名称的变量,并从检查点加载它们的值:

with tf.Graph().as_default():
assert len(tf.trainable_variables()) == 0
v1 = tf.Variable(tf.zeros(1), name='v1')
v2 = tf.Variable(tf.zeros(1), name='v2')

saver = tf.train.Saver(tf.trainable_variables())
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
saver.restore(sess, './tmp.ckpt')
print(sess.run([v1, v2]))

最后一行打印:

[array([1.], dtype=float32), array([2.], dtype=float32)]

关于python - Tensorflow:如何保存变量并将它们加载到不同的变量?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/53129544/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com