gpt4 book ai didi

tensorflow - 如何使用 tf.MonitoredTrainingSession 在训练和验证数据集之间切换?

转载 作者:行者123 更新时间:2023-12-04 08:31:54 28 4
gpt4 key购买 nike

我想用feedable tensorflow Dataset API 中的迭代器设计,因此我可以在一些训练步骤后切换到验证数据。但是如果我切换到验证数据,它将结束整个 session 。

以下代码演示了我想要做什么:

import tensorflow as tf


graph = tf.Graph()
with graph.as_default():
training_ds = tf.data.Dataset.range(32).batch(4)
validation_ds = tf.data.Dataset.range(8).batch(4)

handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
handle, training_ds.output_types, training_ds.output_shapes)
next_element = iterator.get_next()

training_iterator = training_ds.make_initializable_iterator()
validation_iterator = validation_ds.make_initializable_iterator()


with graph.as_default():

with tf.train.MonitoredTrainingSession() as sess:
training_handle = sess.run(training_iterator.string_handle())
validation_handle = sess.run(validation_iterator.string_handle())
sess.run(training_iterator.initializer)
count_training = 0
while not sess.should_stop():
x = sess.run(next_element, feed_dict={handle: training_handle})
count_training += 1
print('{} [training] {}'.format(count_training, x.shape))
# print(x)

# we do periodic validation
if count_training % 4 == 0:
sess.run(validation_iterator.initializer)
count_validation = 0
while not sess.should_stop():
y = sess.run(next_element, feed_dict={handle: validation_handle})
count_validation += 1
print(' {} [validation] {}'.format(count_validation, y.shape))
# print(y)

训练数据有32个元素,batched 4个,所以得到8个batches
我们每 4 步进行一次验证,所以我期望:
#  1 [training]
# 2 [training]
# 3 [training]
# 4 [training]
# 1 [validation]
# 2 [validation]
# 5 [training]
# 6 [training]
# 7 [training]
# 8 [training]
# 1 [validation]
# 2 [validation]

但它在第一次验证完成时停止:
# 1 [training]
# 2 [training]
# 3 [training]
# 4 [training]
# 1 [validation]
# 2 [validation]

那么,如何使用这个 feedable tf.MonitoredTrainingSession 中的迭代器?

最佳答案

我建议 catch tf.errors.OutOfRangeError在验证数据集的末尾引发(您也可以使用 repeat 数据集在官方 API 中检查 the processing multiple epochs section 以获得另一个解决方案):

while not sess.should_stop():
x = sess.run(next_element, feed_dict={handle: training_handle})
count_training += 1
print('{} [training] {}'.format(count_training, x.shape))

# we do periodic validation
if count_training % 4 == 0:
sess.run(validation_iterator.initializer)
count_validation = 0
while True:
try:
y = sess.run(next_element, feed_dict={handle: validation_handle})
count_validation += 1
print(' {} [validation] {}'.format(count_validation, y.shape))
except tf.errors.OutOfRangeError:
break

这段代码打印:
1 [training] (4,)  
2 [training] (4,)
3 [training] (4,)
4 [training] (4,)
1 [validation] (4,)
2 [validation] (4,)
5 [training] (4,)
6 [training] (4,)
7 [training] (4,)
8 [training] (4,)
1 [validation] (4,)
2 [validation] (4,)

关于tensorflow - 如何使用 tf.MonitoredTrainingSession 在训练和验证数据集之间切换?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/49095849/

28 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com