gpt4 book ai didi

tensorflow - 如何正确使用tf.metrics.accuracy?

转载 作者:行者123 更新时间:2023-12-03 08:12:59 26 4
gpt4 key购买 nike

我使用accuracy中的tf.metrics函数来解决多个分类问题,并使用logits作为输入。

我的模型输出如下:

logits = [[0.1, 0.5, 0.4],
[0.8, 0.1, 0.1],
[0.6, 0.3, 0.2]]

我的标签是一种热编码 vector :
labels = [[0, 1, 0],
[1, 0, 0],
[0, 0, 1]]

当我尝试执行类似 tf.metrics.accuracy(labels, logits)的操作时,它永远不会给出正确的结果。我显然做错了,但我不知道是什么。

最佳答案

TL; DR

精度函数tf.metrics.accuracy根据它创建的两个局部变量totalcount来计算预测与标签匹配的频率,这两个变量用于计算logitslabels匹配的频率。

acc, acc_op = tf.metrics.accuracy(labels=tf.argmax(labels, 1), 
predictions=tf.argmax(logits,1))

print(sess.run([acc, acc_op]))
print(sess.run([acc]))
# Output
#[0.0, 0.66666669]
#[0.66666669]
  • acc(准确性):只需使用totalcount返回指标,就不会更新指标。
  • acc_op(更新):更新指标。

  • 要了解acc为什么返回 0.0,请仔细阅读以下详细信息。

    使用一个简单的示例详细信息:
    logits = tf.placeholder(tf.int64, [2,3])
    labels = tf.Variable([[0, 1, 0], [1, 0, 1]])

    acc, acc_op = tf.metrics.accuracy(labels=tf.argmax(labels, 1),
    predictions=tf.argmax(logits,1))

    初始化变量:

    由于 metrics.accuracy创建了两个局部变量 totalcount,因此我们需要调用 local_variables_initializer()对其进行初始化。
    sess = tf.Session()

    sess.run(tf.local_variables_initializer())
    sess.run(tf.global_variables_initializer())

    stream_vars = [i for i in tf.local_variables()]
    print(stream_vars)

    #[<tf.Variable 'accuracy/total:0' shape=() dtype=float32_ref>,
    # <tf.Variable 'accuracy/count:0' shape=() dtype=float32_ref>]

    了解更新操作和准确性计算:
    print('acc:',sess.run(acc, {logits:[[0,1,0],[1,0,1]]}))
    #acc: 0.0

    print('[total, count]:',sess.run(stream_vars))
    #[total, count]: [0.0, 0.0]

    尽管 totalcount为零,但上面给出的精度为0.0,尽管给出了匹配的输入。
    print('ops:', sess.run(acc_op, {logits:[[0,1,0],[1,0,1]]})) 
    #ops: 1.0

    print('[total, count]:',sess.run(stream_vars))
    #[total, count]: [2.0, 2.0]

    使用新输入时,将在调用更新op时计算精度。注意:由于所有logit和label都匹配,因此我们的精度为1.0,而本地变量 totalcount实际上给出了 total correctly predictedtotal comparisons made

    现在我们用新的输入(而不是更新操作)调用 accuracy:
    print('acc:', sess.run(acc,{logits:[[1,0,0],[0,1,0]]}))
    #acc: 1.0

    准确性调用不会使用新的输入来更新指标,它只是使用两个局部变量返回值。注意:在这种情况下,logit和标签不匹配。现在再次调用更新操作:
    print('op:',sess.run(acc_op,{logits:[[0,1,0],[0,1,0]]}))
    #op: 0.75
    print('[total, count]:',sess.run(stream_vars))
    #[total, count]: [3.0, 4.0]

    指标已更新为新输入

    有关培训期间如何使用指标以及在验证期间如何重置指标的更多信息,请访问 here

    关于tensorflow - 如何正确使用tf.metrics.accuracy?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/46409626/

    26 4 0
    Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
    广告合作:1813099741@qq.com 6ren.com