gpt4 book ai didi

python - 如何选择类别概率的最佳阈值?

转载 作者:行者123 更新时间:2023-11-30 08:58:00 24 4
gpt4 key购买 nike

我的神经网络输出是多标签分类的预测类概率表:

print(probabilities)

| | 1 | 3 | ... | 8354 | 8356 | 8357 |
|---|--------------|--------------|-----|--------------|--------------|--------------|
| 0 | 2.442745e-05 | 5.952136e-06 | ... | 4.254002e-06 | 1.894523e-05 | 1.033957e-05 |
| 1 | 7.685694e-05 | 3.252202e-06 | ... | 3.617730e-06 | 1.613792e-05 | 7.356643e-06 |
| 2 | 2.296657e-06 | 4.859554e-06 | ... | 9.934525e-06 | 9.244772e-06 | 1.377618e-05 |
| 3 | 5.163169e-04 | 1.044035e-04 | ... | 1.435158e-04 | 2.807420e-04 | 2.346930e-04 |
| 4 | 2.484626e-06 | 2.074290e-06 | ... | 9.958628e-06 | 6.002510e-06 | 8.434519e-06 |
| 5 | 1.297477e-03 | 2.211737e-04 | ... | 1.881772e-04 | 3.171079e-04 | 3.228884e-04 |

我使用阈值 (0.2) 将其转换为类标签,以衡量预测的准确性:

predictions = (probabilities > 0.2).astype(np.int)
print(predictions)

| | 1 | 3 | ... | 8354 | 8356 | 8357 |
|---|---|---|-----|------|------|------|
| 0 | 0 | 0 | ... | 0 | 0 | 0 |
| 1 | 0 | 0 | ... | 0 | 0 | 0 |
| 2 | 0 | 0 | ... | 0 | 0 | 0 |
| 3 | 0 | 0 | ... | 0 | 0 | 0 |
| 4 | 0 | 0 | ... | 0 | 0 | 0 |
| 5 | 0 | 0 | ... | 0 | 0 | 0 |

我还有一个测试集:

print(Y_test)

| | 1 | 3 | ... | 8354 | 8356 | 8357 |
|---|---|---|-----|------|------|------|
| 0 | 0 | 0 | ... | 0 | 0 | 0 |
| 1 | 0 | 0 | ... | 0 | 0 | 0 |
| 2 | 0 | 0 | ... | 0 | 0 | 0 |
| 3 | 0 | 0 | ... | 0 | 0 | 0 |
| 4 | 0 | 0 | ... | 0 | 0 | 0 |
| 5 | 0 | 0 | ... | 0 | 0 | 0 |

问题:如何在 Python 中构建一个算法来选择最大化 roc_auc_score(average = 'micro') 或其他指标的最佳阈值?

也许可以在 Python 中构建手动函数来优化阈值,具体取决于准确性指标。

最佳答案

我假设您的真实标签是Y_test,预测是预测

根据预测阈值优化roc_auc_score(average = 'micro')似乎没有意义,因为AUC是根据预测的排名方式计算的,因此需要预测作为[0,1]中的浮点值。

因此,我将讨论accuracy_score

您可以使用scipy.optimize.fmin :

import scipy
from sklearn.metrics import accuracy_score

def thr_to_accuracy(thr, Y_test, predictions):
return -accuracy_score(Y_test, np.array(predictions>thr, dtype=np.int))

best_thr = scipy.optimize.fmin(thr_to_accuracy, args=(Y_test, predictions), x0=0.5)

关于python - 如何选择类别概率的最佳阈值?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/52093388/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com