gpt4 book ai didi

python - Tensorflow 从字符串句柄创建第二个迭代器 - GetNext() 失败,因为未初始化

转载 作者:行者123 更新时间:2023-12-05 07:32:26 26 4
gpt4 key购买 nike

您好,我的数据集迭代器突然遇到问题。我已经看到关于堆栈溢出的类似问题,但没有一个能够对我的情况有所帮助,所以我将其发布在这里。

当我在训练后创建验证迭代器时,我的代码运行完美。但现在我想看看损失在测试集上的表现如何,因此需要训练和测试数据集。无论如何,当我尝试运行我的代码时,它总是说我的第二个迭代器尚未初始化,我相信它已经初始化了。我几乎尝试了一切,使用 variable_scape,重命名变量等。如果有人可以看看我的代码并告诉我我哪里出错了?我非常关注 tensorflow example在关于从字符串句柄创建迭代器的文档中。

 def run(self, model="NN", use_gazemap=False):
# Input with gazemap or without

graph_input = self.projection(use_gazemap=self.use_gazemap)
self.predictions = self.classification_graph_nn(graph_input)

handle = tf.placeholder(tf.string, shape=[])
# Valid Dataset
valid_size = 65268
self.valid_iterator, self.valid_dataset = load_data("valid",
self.batch_size, "valid.tfrecord")


#Train Dataset
train_size = 58212
self.train_iterator, self.train_dataset = load_data("train",
self.batch_size, "train.tfrecord")


# Iterator
iterator = tf.data.Iterator.from_string_handle(handle,
self.train_dataset.output_types,
self.train_dataset.output_shapes)

next_element = iterator.get_next()


valid_handle = self.session.run(self.valid_iterator.string_handle())
training_handle = self.session.run(self.train_iterator.string_handle())
self.session.run(tf.global_variables_initializer())


run_options = tf.RunOptions(report_tensor_allocations_upon_oom=True)

# Summary
with tf.variable_scope('logging'):
tf.summary.scalar('current_cost', self.loss)
tf.summary.scalar('learning_rate', self.learning_rate)
summary = tf.summary.merge_all()

training_writer = tf.summary.FileWriter(
'./logs/training', self.session.graph)
testing_writer = tf.summary.FileWriter('.logs/testing', self.session.graph)


# Training model
for epoch in range(hparams.num_epochs):
self.session.run(self.train_iterator.initializer)
for it in range(train_size / hparams.batch_size):

# Training
frames, c3d, labels, gaze_gt, gaze_pred = self.session.run(
next_element, feed_dict={handle: training_handle})
feed_dict = {self.c3d: c3d,
self.gazemap: gaze_gt, self.labels: labels}
loss, _, global_step, learning_rate, training_summary = self.session.run(
[self.loss, self.train_op, self.global_step, self.learning_rate, summary], feed_dict=feed_dict, options=run_options)

# Testing
frames, c3d, labels, gaze_gt, gaze_pred = self.session.run(next_element, feed_dict={handle: valid_handle})
feed_dict = {self.c3d: c3d,
self.gazemap: gaze_gt, self.labels: labels}
test_loss, testing_summary = self.session.run(
[self.loss, summary], feed_dict=feed_dict, options=run_options)

if global_step % self.steps_per_logprint == 0:
self.session.run(self.predictions,
feed_dict=feed_dict, options=run_options)
batch_score = self.evaluate(self.predictions.eval(
feed_dict=feed_dict), self.labels.eval(feed_dict=feed_dict))

最佳答案

您必须像初始化训练迭代器一样初始化验证迭代器。为此,请在纪元的开头添加此行:

self.session.run(self.valid_iterator.initializer)

关于python - Tensorflow 从字符串句柄创建第二个迭代器 - GetNext() 失败,因为未初始化,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/51220714/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com