gpt4 book ai didi

python - Tensorflow:对 tf.estimator.inputs.numpy_input_fn 函数进行故障排除

转载 作者:太空宇宙 更新时间:2023-11-03 11:18:57 30 4
gpt4 key购买 nike

我正在运行来自 text classification 的一些教程代码

我可以运行脚本并且它有效,但是当我尝试逐行运行它以试图了解每个步骤在做什么时,我在这一步有点困惑:

test_input_fn = tf.estimator.inputs.numpy_input_fn(
x={WORDS_FEATURE: x_test},
y=y_test,
num_epochs=1,
shuffle=False)
classifier.train(input_fn=train_input_fn, steps=100)

我从概念上知道 train_input_fn 正在向训练函数提供数据,但我如何手动调用此 fn 来检查其中的内容?

我跟踪了代码,发现 train_input_fn 函数将数据提供给以下 2 个变量:

features
Out[15]: {'words': <tf.Tensor 'random_shuffle_queue_DequeueMany:1' shape=(560, 10) dtype=int64>}

labels
Out[16]: <tf.Tensor 'random_shuffle_queue_DequeueMany:2' shape=(560,) dtype=int32>

当我尝试通过执行 sess.run(features) 来评估 features 变量时,我的终端似乎卡住了并​​停止响应。

检查这些变量内容的正确方法是什么?

谢谢!

最佳答案

基于numpy_input_fn documentation和行为(挂起)我想底层实现取决于队列运行器。队列运行器未启动时会发生挂起。尝试根据 this guide 将您的 session 运行脚本修改为如下内容:

with tf.Session() as sess:
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
try:
for step in xrange(1000000):
if coord.should_stop():
break
features_data = sess.run(features)
print(features_data)

except Exception, e:
# Report exceptions to the coordinator.
coord.request_stop(e)
finally:
# Terminate as usual. It is safe to call `coord.request_stop()` twice.
coord.request_stop()
coord.join(threads)

或者,我鼓励您查看 tf.data.Dataset 接口(interface)(在 tensorflow 1.3 或更早版本中可能是 tf.contrib.data.Dataset)。您可以获得类似的输入/标签张量,而无需使用 Dataset.from_tensor_slices 的队列。创建稍微复杂一些,但接口(interface)更加灵活,并且实现不使用队列运行器,这意味着 session 运行要简单得多。

import tensorflow as tf
import numpy as np

x_data = np.random.random((100000, 2))
y_data = np.random.random((100000,))

batch_size = 2
buff = 100


def input_fn():
# possible tf.contrib.data.Dataset.from... in tf 1.3 or earlier
dataset = tf.data.Dataset.from_tensor_slices((x_data, y_data))
dataset = dataset.repeat().shuffle(buff).batch(batch_size)
x, y = dataset.make_one_shot_iterator().get_next()
return x, y


x, y = input_fn()
with tf.Session() as sess:
print(sess.run([x, y]))

关于python - Tensorflow:对 tf.estimator.inputs.numpy_input_fn 函数进行故障排除,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/46762932/

30 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com