gpt4 book ai didi

python - 如何使用 tf.train.Checkpoint 保存大量变量

转载 作者:行者123 更新时间:2023-11-30 09:42:22 24 4
gpt4 key购买 nike

我可以通过以下方式在检查点( https://www.tensorflow.org/beta/guide/checkpoints#manually_inspecting_checkpoints )中保存两个变量(v1,v2)。但是如果我有很多变量(v3,v4 ...),该怎么办?如果我使用相同的方式(v1 = v1,v2 = v2,v3 = v3,v4 = v4 ..)有很多参数。有没有一种方便的方法来做到这一点?例如,由于 tf.train.Checkpoint 可以接受 keras 对象,我可以将所有变量放在一个对象中吗?

opt = tf.train.AdamOptimizer(0.1)
net = Net()

v1 = tf.get_variable("v1", [3], initializer = tf.zeros_initializer)
v2 = tf.get_variable("v2", [5], initializer = tf.zeros_initializer)


ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net,v2=v2,v1=v1)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)
ckpt.restore(manager.latest_checkpoint)
if manager.latest_checkpoint:
print("Restored from {}".format(manager.latest_checkpoint))
else:
print("Initializing from scratch.")

for example in toy_dataset():
loss = train_step(net, example, opt)
ckpt.step.assign_add(1)
if int(ckpt.step) % 10 == 0:
save_path = manager.save()
print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))
print("loss {:1.2f}".format(loss.numpy()))

最佳答案

让我重新表述一下您的问题,以确保我理解正确:“如何在检查点中保存许多 tf 变量?”。

如果是这样,那么我的答案是将所有这些变量放入范围内,并按照此答案中建议的方式访问它们:https://stackoverflow.com/a/41642426/11708498

关于python - 如何使用 tf.train.Checkpoint 保存大量变量,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/57215717/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com