gpt4 book ai didi

tensorflow - 如何检查点 tf.data 数据集对象?

转载 作者:行者123 更新时间:2023-12-03 20:51:57 25 4
gpt4 key购买 nike

在训练期间设置检查点时(以防崩溃等),我保存了图形和参数,但不清楚如何对用于输入的新 tf.data 对象执行相同操作。

是否有一种直接的方法来检查这些点,以便我可以继续当前的纪元,或者恢复随机播放状态(可能来自种子?)

最佳答案

tf.contrib.data.make_saveable_from_iterator()函数接受一个 tf.data.Iterator 对象并返回一个“可保存对象”,可以使用 tf.train.Saver 保存该对象.它保存迭代器的整个状态,包括任何打乱的数据。

以下示例代码展示了如何将一个简单的迭代器添加到用于变量的相同检查点:

ds = tf.data.Dataset.range(10)
iterator = ds.make_initializable_iterator()

# [Build the training graph, using `iterator.get_next()` as the input.]

# Build the iterator SaveableObject.
saveable_obj = tf.contrib.data.make_saveable_from_iterator(iterator)

# Add the SaveableObject to the SAVEABLE_OBJECTS collection so
# it will be saved automatically using a Saver.
tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable_obj)

# Create a saver that saves all objects in the `tf.GraphKeys.SAVEABLE_OBJECTS`
# collection.
saver = tf.train.Saver()

with tf.Session() as sess:
while continue_training:

# [Perform training.]

if should_save_checkpoint:
saver.save(sess, ...)

请注意,迭代器检查点支持目前(从 TensorFlow 1.8 开始)处于实验状态,因此检查点格式可能会从一个版本更改为下一个版本。

关于tensorflow - 如何检查点 tf.data 数据集对象?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/50265761/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com