gpt4 book ai didi

python - 使用 Tensorflow 保存检查点

转载 作者:行者123 更新时间:2023-12-05 07:34:20 25 4
gpt4 key购买 nike

我的 CNN 模型有 3 个文件夹,它们是 train_data、val_data、test_data。

当我训练我的模型时,我发现准确度可能会有所不同,有时最后一个 epoch 并没有显示出最佳准确度。例如,上一个纪元的准确度是 71%,但我发现在较早的纪元中准确度更高。我想保存那个具有更高准确度的 epoch 的检查点,然后使用该检查点在 test_data

上预测我的模型

我在 train_data 上训练了我的模型并在 val_data 上进行了预测并保存了模型的检查点,如下所示:

    print("{} Saving checkpoint of model...". format(datetime.now()))
checkpoint_path = os.path.join(checkpoint_dir, 'model_epoch' + str(epoch) + '.ckpt')
save_path = saver.save(session, checkpoint_path)

在开始 tf.Session() 之前,我有这一行:

saver = tf.train.Saver()

我想知道如何保存具有更高准确度的最佳时期,然后将此检查点用于我的test_data

最佳答案

tf.train.Saver() 文档描述了以下内容:

saver.save(sess, 'my-model', global_step=0) ==> filename: 'my-model-0'
...
saver.save(sess, 'my-model', global_step=1000) ==> filename: 'my-model-1000'

请注意,如果您将 global_step 传递给保存程序,您将生成包含全局步骤编号的检查点文件。我通常每隔 X 分钟保存一次检查点,然后回来查看结果并在适当的步长值处选择一个检查点。如果您使用的是 tensorboard,您会发现这很直观,因为您的所有图表也可以按全局步骤显示。

https://www.tensorflow.org/api_docs/python/tf/train/Saver

关于python - 使用 Tensorflow 保存检查点,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/50178866/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com