gpt4 book ai didi

tensorflow - 如何使用 tf.metrics 计算多标签分类的准确性?

转载 作者:行者123 更新时间:2023-12-03 12:53:50 25 4
gpt4 key购买 nike

我想用 tensorflow (tf.estimator.Estimator) 训练多标签分类模型。我需要在评估时输出准确性。但它似乎不适用于以下代码:

accuracy = tf.metrics.accuracy(labels=labels, predictions=preds)
metrics = {'accuracy': accuracy}

if mode == tf.estimator.ModeKeys.EVAL:
return tf.estimator.EstimatorSpec(mode, loss=loss, eval_metric_ops=metrics)
tf.metrics.accuracy不是多热结果。那么什么是多标签指标呢?

最佳答案

其实 tf.metrics.accuracy 也计算多标签分类的准确性。请参阅下面的示例:

import tensorflow as tf

labels = tf.constant([[1, 0, 0, 1],
[0, 1, 1, 1],
[1, 1, 0, 0],
[0, 0, 0, 1],
[1, 1, 0, 0]])

preds = tf.constant([[1, 0, 1, 1],
[0, 1, 1, 1],
[1, 1, 0, 0],
[0, 0, 0, 1],
[1, 1, 0, 0]])

acc, acc_op = tf.metrics.accuracy(labels, preds)

with tf.Session() as sess:
sess.run(tf.local_variables_initializer())
sess.run(tf.global_variables_initializer())
print(sess.run([acc, acc_op]))
print(sess.run([acc]))

如您所见,我们总共有 20 个标签,第一行中只有一个条目被错误标记,因此我们的准确率为 0.95%。

关于tensorflow - 如何使用 tf.metrics 计算多标签分类的准确性?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/52832618/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com