gpt4 book ai didi

python - 恢复 Tensorflow 中的变量子集

转载 作者:太空狗 更新时间:2023-10-29 17:11:06 27 4
gpt4 key购买 nike

我正在使用 tensorflow 训练生成对抗网络 (GAN),基本上我们有两个不同的网络,每个网络都有自己的优化器。

self.G, self.layer = self.generator(self.inputCT,batch_size_tf)
self.D, self.D_logits = self.discriminator(self.GT_1hot)

...

self.g_optim = tf.train.MomentumOptimizer(self.learning_rate_tensor, 0.9).minimize(self.g_loss, global_step=self.global_step)

self.d_optim = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5) \
.minimize(self.d_loss, var_list=self.d_vars)

问题是我先训练其中一个网络 (g),然后,我想一起训练 g 和 d。但是,当我调用加载函数时:

self.sess.run(tf.initialize_all_variables())
self.sess.graph.finalize()

self.load(self.checkpoint_dir)

def load(self, checkpoint_dir):
print(" [*] Reading checkpoints...")

ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
if ckpt and ckpt.model_checkpoint_path:
ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
self.saver.restore(self.sess, ckpt.model_checkpoint_path)
return True
else:
return False

我有这样的错误(有更多回溯):

Tensor name "beta2_power" not found in checkpoint files checkpoint/MR2CT.model-96000

我可以恢复 g 网络并继续使用该函数进行训练,但是当我想从头开始 d 和存储模型中的 g 时,我遇到了那个错误。

最佳答案

要恢复变量子集,您必须创建一个新的 tf.train.Saver并在可选的 var_list 参数中传递一个特定的变量列表以恢复。

默认情况下,tf.train.Saver 将创建以下操作:(i) 在您调用 saver.save() 时保存图表中的每个变量。和 (ii) 在您调用 saver.restore() 时查找(按名称)给定检查点中的每个变量.虽然这适用于大多数常见场景,但您必须提供更多信息才能使用变量的特定子集:

  1. 如果您只想恢复变量的一个子集,您可以通过调用 tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=G_NETWORK_PREFIX) 获得这些变量的列表。 ,假设您将“g”网络放在一个公共(public)的 with tf.name_scope(G_NETWORK_PREFIX): 中或 tf.variable_scope(G_NETWORK_PREFIX):堵塞。然后,您可以将此列表传递给 tf.train.Saver 构造函数。

  2. 如果您想恢复变量的子集和/或检查点中的变量具有不同的名称,您可以将字典作为var_list争论。默认情况下,检查点中的每个变量都与一个相关联,这是其tf.Variable.name 属性的值。如果目标图中的名称不同(例如,因为您添加了范围前缀),您可以指定一个字典,将字符串键(在检查点文件中)映射到 tf.Variable 对象(在目标中图)。

关于python - 恢复 Tensorflow 中的变量子集,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/41621071/

27 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com