gpt4 book ai didi

tensorflow - 如何使用来自 Tensorflow Dataset API 的可馈送迭代器和 MonitoredTrainingSession?

转载 作者:行者123 更新时间:2023-12-04 17:31:10 25 4
gpt4 key购买 nike

Tensorflow programmer's guide建议使用可馈送迭代器在训练和验证数据集之间切换,而无需重新初始化迭代器。主要是需要进给 handle 在它们之间进行选择。

如何与它一起使用 tf.train.MonitoredTrainingSession ?

以下方法失败并显示“RuntimeError: Graph is finalized and cannot be modified”。错误。

with tf.train.MonitoredTrainingSession() as sess:
training_handle = sess.run(training_iterator.string_handle())
validation_handle = sess.run(validation_iterator.string_handle())

如何同时实现 MonitoredTrainingSession 的便利性和迭代训练和验证数据集?

最佳答案

我从 Tensorflow GitHub 问题中得到了答案 - https://github.com/tensorflow/tensorflow/issues/12859

解决方案是调用 iterator.string_handle()在创建 MonitoredSession 之前.

import tensorflow as tf
from tensorflow.contrib.data import Dataset, Iterator

dataset_train = Dataset.range(10)
dataset_val = Dataset.range(90, 100)

iter_train_handle = dataset_train.make_one_shot_iterator().string_handle()
iter_val_handle = dataset_val.make_one_shot_iterator().string_handle()

handle = tf.placeholder(tf.string, shape=[])
iterator = Iterator.from_string_handle(
handle, dataset_train.output_types, dataset_train.output_shapes)
next_batch = iterator.get_next()

with tf.train.MonitoredTrainingSession() as sess:
handle_train, handle_val = sess.run([iter_train_handle, iter_val_handle])

for step in range(10):
print('train', sess.run(next_batch, feed_dict={handle: handle_train}))

if step % 3 == 0:
print('val', sess.run(next_batch, feed_dict={handle: handle_val}))

Output:
('train', 0)
('val', 90)
('train', 1)
('train', 2)
('val', 91)
('train', 3)

关于tensorflow - 如何使用来自 Tensorflow Dataset API 的可馈送迭代器和 MonitoredTrainingSession?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/46111072/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com