gpt4 book ai didi

tensorflow - 哪个Loss function & Metrics更适合多标签分类?二元或分类交叉熵,为什么?

转载 作者:行者123 更新时间:2023-12-05 06:09:45 24 4
gpt4 key购买 nike

据我所知(如有错误请指正)

多标签分类(相互包含),即样本可能有 1 个以上的正确值(例如电影类型、疾病检测等)。

多类分类(互斥),即样本将始终具有 1 个正确值(例如猫或狗、对象检测等),这包括二元分类。

假设输出是单热编码。

这两种类型必须使用的损失函数和指标是什么?

                     loss func.          metrics
1. multi-label: (binary, categorical) (binary_accuracy, TopKCategorical accuracy, categorical_accuracy, AUC)
2. multi-class: (binary) (binary_accuracy,f1, recall, precision)

请从上表中告诉我哪些更合适,哪些是错误的,为什么?

最佳答案

如果您尝试使用多类分类,前提是标签 (y) 是一个热编码,请使用损失函数作为分类交叉熵并使用 adam 优化器(适用于大多数情况)。此外,在使用多类分类时,输出节点的数量应与类(或)标签的数量相同。假设您的模型要将输入分为 4 类,您可以按如下方式配置输出层。

model.add(4, activation = "softmax")

此外,忘记提及对于多类分类问题,应该在输出层使用 softmax 激活。

如果您的 y 不是热编码的,我建议您选择损失函数作为稀疏分类交叉熵。无需进行其他更改。

此外,我通常将数据拆分为测试数据和训练数据,然后像这样将它们提供给模型以获得每个时期的准确性。

history = model.fit(train_data, validation_data = test_data, epochs = 10)

希望它能解决您的问题。

关于tensorflow - 哪个Loss function & Metrics更适合多标签分类?二元或分类交叉熵,为什么?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/64641430/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com