gpt4 book ai didi

TensorFlow:关于 tf.argmax() 和 tf.equal() 的问题

转载 作者:行者123 更新时间:2023-12-03 08:46:56 24 4
gpt4 key购买 nike

我正在学习 TensorFlow,构建一个多层感知器模型。我正在研究一些示例,例如:https://github.com/aymericdamien/TensorFlow-Examples/blob/master/notebooks/3_NeuralNetworks/multilayer_perceptron.ipynb

然后我在下面的代码中有一些问题:

def multilayer_perceptron(x, weights, biases):
:
:

pred = multilayer_perceptron(x, weights, biases)
:
:

with tf.Session() as sess:
sess.run(init)
:
correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))

accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
print ("Accuracy:", accuracy.eval({x: X_test, y: y_test_onehot}))

我想知道做什么 tf.argmax(prod,1)tf.argmax(y,1)意思和返回(类型和值)到底是什么?并且是 correct_prediction变量而不是实际值?

最后,我们如何获得 y_test_prediction tf session 中的数组(输入数据为 X_test 时的预测结果)?非常感谢!

最佳答案

tf.argmax(input, axis=None, name=None, dimension=None)

返回张量轴上具有最大值的索引。

输入是一个张量,轴描述输入张量的哪个轴要减少。对于向量,使用axis = 0。

对于您的特定情况,让我们使用两个数组并演示这一点
pred = np.array([[31, 23,  4, 24, 27, 34],
[18, 3, 25, 0, 6, 35],
[28, 14, 33, 22, 20, 8],
[13, 30, 21, 19, 7, 9],
[16, 1, 26, 32, 2, 29],
[17, 12, 5, 11, 10, 15]])

y = np.array([[31, 23, 4, 24, 27, 34],
[18, 3, 25, 0, 6, 35],
[28, 14, 33, 22, 20, 8],
[13, 30, 21, 19, 7, 9],
[16, 1, 26, 32, 2, 29],
[17, 12, 5, 11, 10, 15]])

正在评估 tf.argmax(pred, 1)给出一个张量,其求值将给出 array([5, 5, 2, 1, 3, 0])
正在评估 tf.argmax(y, 1)给出一个张量,其求值将给出 array([5, 5, 2, 1, 3, 0])
tf.equal(x, y, name=None) takes two tensors(x and y) as inputs and returns the truth value of (x == y) element-wise. 

按照我们的例子, tf.equal(tf.argmax(pred, 1),tf.argmax(y, 1))返回一个张量,其评估将给出 array(1,1,1,1,1,1) .

right_prediction 是一个张量,其评估将给出 0 和 1 的一维数组

y_test_prediction 可以通过执行 pred = tf.argmax(logits, 1) 获得

可以通过以下链接访问 tf.argmax 和 tf.equal 的文档。

tf.argmax() https://www.tensorflow.org/api_docs/python/math_ops/sequence_comparison_and_indexing#argmax

tf.equal() https://www.tensorflow.org/versions/master/api_docs/python/control_flow_ops/comparison_operators#equal

关于TensorFlow:关于 tf.argmax() 和 tf.equal() 的问题,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/41708572/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com