gpt4 book ai didi

python - tf.argmax() 返回意外结果

转载 作者:行者123 更新时间:2023-12-01 09:31:35 26 4
gpt4 key购买 nike

我最近正在制作一个基于 TensorFlow CNN、MNIST 数据集和服务器接口(interface)的项目。

在预测部分,我使用tf.argmax()来获取最大的logit,这将是预测值。但是,它返回的值似乎不是正确的答案。

预测函数大约是这样的:

    self.img = tf.reshape(tf.image.convert_image_dtype(img, tf.float32), shape=[1, 28, 28, 1])
self._create_model()

saver = tf.train.Saver()
ckpt = tf.train.get_checkpoint_state('../checkpoints/')
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess, ckpt.model_checkpoint_path)
saver.restore(sess, ckpt.model_checkpoint_path)

pred = tf.nn.softmax(self.logits)
prediction = tf.argmax(pred, 1)
logit = sess.run(pred)
result = sess.run(prediction)[0]
print(logit)
print(result)

return result

结果是:

127.0.0.1 - - [19/Apr/2018 21:35:47] "POST /index.html HTTP/1.1" 200 -
[[ 0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]]
1

如您所见,logits 显示最大数字的索引是 5,但是 tf.argmax() 给了我 1> 相反。

顺便说一下,my model是基本的 MNIST CNN 模型,如您在链接中看到的。

那么这个tf.argmax()函数发生了什么,或者我的代码有问题?

最佳答案

由于您的 logit(pred) 和 result(prediction[0]) 来自两个不同的 sess.run,我想知道运行之间是否存在一些差异。例如,图中有一个迭代器将输入发送到模型。通过不同的运行,迭代器发送不同的数据,从而导致不同的预测。如果将 predprediction 放在同一个 sess.run 中,如下所示,将会很有趣:

logit, result = sess.run((pred, prediction))
print(logit)
print(result[0])

关于python - tf.argmax() 返回意外结果,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/49923675/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com