gpt4 book ai didi

tensorflow - Tensorflow 中的平衡准确度分数

转载 作者:行者123 更新时间:2023-11-30 08:47:53 27 4
gpt4 key购买 nike

我正在为高度不平衡的分类问题实现 CNN,并且我想在 tensorflow 中实现 custum 指标以使用“选择最佳模型”回调。具体来说,我想实现平衡的准确度分数,即每个类别的召回率的平均值(请参阅sklearn实现 here ),有人知道该怎么做吗?

最佳答案

我遇到了同样的问题,因此我实现了一个基于 SparseCategoricalAccuracy 的自定义类:

class BalancedSparseCategoricalAccuracy(keras.metrics.SparseCategoricalAccuracy):
def __init__(self, name='balanced_sparse_categorical_accuracy', dtype=None):
super().__init__(name, dtype=dtype)

def update_state(self, y_true, y_pred, sample_weight=None):
y_flat = y_true
if y_true.shape.ndims == y_pred.shape.ndims:
y_flat = tf.squeeze(y_flat, axis=[-1])
y_true_int = tf.cast(y_flat, tf.int32)

cls_counts = tf.math.bincount(y_true_int)
cls_counts = tf.math.reciprocal_no_nan(tf.cast(cls_counts, self.dtype))
weight = tf.gather(cls_counts, y_true_int)
return super().update_state(y_true, y_pred, sample_weight=weight)

这个想法是将每个类的权重设置为其大小成反比。

此代码会产生一些来自 Autograph 的警告,但我相信这些是 Autograph 错误,并且该指标似乎工作正常。

关于tensorflow - Tensorflow 中的平衡准确度分数,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/59339531/

27 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com