gpt4 book ai didi

python - 用于同步训练和验证的可重新初始化迭代器

转载 作者:太空宇宙 更新时间:2023-11-03 15:50:08 24 4
gpt4 key购买 nike

我想使用 DatasetIterator 在训练期间对验证集进行评估。我想时不时地对一个(或几个)验证批处理进行评估——时不时地通常不是一个时代。

然而,可重新初始化的迭代器在重新初始化以切换其输入时会重新开始。例如

import tensorflow as tf

dataset_trn = tf.data.Dataset.range(10)
dataset_tst = tf.data.Dataset.range(10).map(lambda i: i + 1000)
iterator = tf.data.Iterator.from_structure(dataset_trn.output_types,
dataset_trn.output_shapes)
batch = iterator.get_next()
trn_init_op = iterator.make_initializer(dataset_trn)
tst_init_op = iterator.make_initializer(dataset_tst)

sess = tf.InteractiveSession()

for _ in range(2):
sess.run(trn_init_op)
for _ in range(5):
print(batch.eval())

sess.run(tst_init_op)
print(batch.eval())

返回

0
1
2
3
4
1000
0
1
2
3
4
1000

但我希望它能像那样恢复训练:

0
1
2
3
4
1000
5
6
7
8
9
1001

有办法实现吗?请注意,在实践中,批处理会被打乱,我希望它在相同的伪随机点恢复。

最佳答案

Feedable iterators应该有帮助,但他们很难合作。您需要创建占位符和字符串句柄:

dataset_trn = tf.data.Dataset.range(10)
dataset_tst = tf.data.Dataset.range(10).map(lambda i: i + 1000)

holder = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
holder, dataset_trn.output_types, dataset_trn.output_shapes)
batch = iterator.get_next()

trn_iter = dataset_trn.make_one_shot_iterator()
trn_handle = trn_iter.string_handle()

tst_iter = dataset_tst.make_one_shot_iterator()
tst_handle = tst_iter.string_handle()

with tf.Session() as sess:
for _ in range(2):

trn_string = sess.run(trn_handle)
tst_string = sess.run(tst_handle)

for _ in range(5):
print(sess.run(batch, feed_dict={holder: trn_string}))

print(sess.run(batch, feed_dict={holder: tst_string}))

关于python - 用于同步训练和验证的可重新初始化迭代器,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/47431777/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com