gpt4 book ai didi

machine-learning - Tensorflow:保存验证误差最小的模型

转载 作者:行者123 更新时间:2023-11-30 08:23:08 24 4
gpt4 key购买 nike

我使用 tensorflow 进行了训练工作,并得到了以下验证集损失曲线。第 6000 次迭代后网络开始过度拟合。所以我想在过度拟合之前获得模型。

loss

我的训练代码如下所示:

train_step = ......
summary = tf.scalar_summary(l1_loss.op.name, l1_loss)
summary_writer = tf.train.SummaryWriter("checkpoint", sess.graph)
saver = tf.train.Saver()
for i in xrange(20000):
batch = get_next_batch(batch_size)
sess.run(train_step, feed_dict = {x: batch.x, y:batch.y})
if (i+1) % 100 == 0:
saver.save(sess, "checkpoint/net", global_step = i+1)
summary_str = sess.run(summary, feed_dict=validation_feed_dict)
summary_writer.add_summary(summary_str, i+1)
summary_writer.flush()

训练完成后,仅保存了五个检查点(19600、19700、19800、19900、20000)。有没有办法让tensorflow根据验证错误保存checkpoint?

附注我知道 tf.train.Saver 有一个 max_to_keep 参数,它原则上可以保存所有检查点。但这不是我想要的(除非这是唯一的选择)。我希望保存者保留迄今为止验证损失最小的检查点。这可能吗?

最佳答案

您需要计算验证集的分类精度并跟踪迄今为止看到的最好的分类精度,只有在发现验证精度有所改进后才编写检查点。

如果数据集和/或模型很大,那么您可能必须将验证集分成批处理以适应内存中的计算。

本教程准确展示了如何做您想做的事情:

https://github.com/Hvass-Labs/TensorFlow-Tutorials/blob/master/04_Save_Restore.ipynb

它也可以作为短视频提供:

https://www.youtube.com/watch?v=Lx8JUJROkh0

关于machine-learning - Tensorflow:保存验证误差最小的模型,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/39252901/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com