gpt4 book ai didi

python - 如何使用 TensorFlow 从 MNIST 获得对一张图像的预测?

转载 作者:太空宇宙 更新时间:2023-11-04 04:48:20 24 4
gpt4 key购买 nike

我遵循了本教程 https://www.tensorflow.org/tutorials/layers并训练了一个模型来识别 MNIST 集中的手写数字。

以下代码按预期工作并为集合中的每个图像打印概率和类别

mnist = tf.contrib.learn.datasets.load_dataset("mnist")
train_data = mnist.train.images # Returns np.array
tf.reset_default_graph()
with tf.Session() as sess:

mnist_classifier = tf.estimator.Estimator(model_fn=cnn_model_fn, model_dir="model/")

pred = mnist_classifier.predict(input_fn=tf.estimator.inputs.numpy_input_fn(
x={"x": train_data},
shuffle=False))

for p in pred:
print(p)

但是,当我尝试只预测一张图像时

mnist_classifier.predict(input_fn=tf.estimator.inputs.numpy_input_fn(
x={"x": train_data[0]},
shuffle=False))

我的程序失败并且 TensorFlow 报告

InvalidArgumentError: Input to reshape is a tensor with 128 values,
but the requested shape requires a multiple of 784

这让我很困惑,因为当我打印集合中第一张图像的长度时,它报告 784

print("length of input: {}".format(len(train_data[0]))

我如何获得对一张图片的预测?

最佳答案

这很可能与您在创建单项数据集时删除了批处理维度有关。我的意思是你应该使用

mnist_classifier.predict(input_fn=tf.estimator.inputs.numpy_input_fn(
x={"x": np.array([train_data[0])]},
shuffle=False))

相反。请注意围绕 train_data[0] 的附加列表。这将采用形状为 [1, 784] 的数组并创建一个包含一个元素的数据集,这又是一个包含 784 个元素的向量。正如您现在的代码一样,您基本上是在创建一个包含 784 个元素的数据集,每个元素都是一个数字。这会导致形状不匹配。

关于python - 如何使用 TensorFlow 从 MNIST 获得对一张图像的预测?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/48975337/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com