gpt4 book ai didi

python - 如何检查 tf.estimator.inputs.numpy_input_fn 的内容?

转载 作者:太空宇宙 更新时间:2023-11-04 02:24:28 25 4
gpt4 key购买 nike

我想在一组数据上反复训练我的 tensorflow 图,我想 tf.estimator.inputs.numpy_input_fn可能是我要找的。我发现批量大小、重复次数、时期和迭代器之间的区别令人难以置信的困惑,所以我开始尝试检查我的数据集的内容,试图弄清楚到底发生了什么。但是,每当我尝试这样做时,我的程序就会挂起。

这是我想出的最小的测试用例来重现这个:

import tensorflow as tf
import numpy

class TestMock(tf.test.TestCase):
def test(self):
inputs = numpy.array(range(10))
targets = numpy.array(range(10,20))

input_fn = tf.estimator.inputs.numpy_input_fn(
x=inputs,
y=targets,
batch_size=1,
num_epochs=2,
shuffle=False)

print input_fn()
with self.test_session() as sess:
# sess.run(input_fn()[0]) # it'll hang if I run this
pass

if __name__ == '__main__':
tf.test.main()

这个程序输出

(<tf.Tensor 'fifo_queue_DequeueUpTo:1' shape=(?,) dtype=int64>, <tf.Tensor 'fifo_queue_DequeueUpTo:2' shape=(?,) dtype=int64>)

这似乎是合理的,但是当我尝试运行该 sess.run 行时,我的程序就会卡住,我必须终止该进程。我在这里做错了什么?

我想做的是确保我输入流程的数据实际上是我认为的那样,但我认为如果没有检查数据的能力我无法做到这一点。

最佳答案

从上面的打印语句我们可以推断input_fn返回queue ops,我们需要使用start_queue_runners and Coordinator来运行它们:

 features_op, labels_op = input_fn()
with tf.Session() as sess:
# initialise and start the queues.
sess.run(tf.local_variables_initializer())

coordinator = tf.train.Coordinator()
_ = tf.train.start_queue_runners(coord=coordinator)

print(sess.run([features_op, labels_op]))

#[array([0]), array([10])]

关于python - 如何检查 tf.estimator.inputs.numpy_input_fn 的内容?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/50789693/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com