- Java锁的逻辑(结合对象头和ObjectMonitor)
- 还在用饼状图?来瞧瞧这些炫酷的百分比可视化新图形(附代码实现)⛵
- 自动注册实体类到EntityFrameworkCore上下文,并适配ABP及ABPVNext
- 基于Sklearn机器学习代码实战
作者:PrimiHub-Kevin 。
ROC 曲线是一种坐标图式的分析工具 ,是由二战中的电子和雷达工程师发明的,发明之初是用来侦测敌军飞机、船舰,后来被应用于医学、生物学、犯罪心理学.
如今, ROC 曲线已经被广泛应用于机器学习领域的模型评估 ,说到这里就不得不提到 Tom Fawcett 大佬,他一直在致力于推广 ROC 在机器学习领域的应用,他发布的论文 《An introduction to ROC analysis》 更是被奉为 ROC 的经典之作 (引用 2.2w 次), 知名机器学习库 scikit-learn 中的 ROC 算法就是参考此论文实现 ,可见其影响力! 。
不知道大多数人是否和我一样, 对于 ROC 曲线的理解只停留在调用 scikit-learn 库的函数 ,对于它的背后原理和公式所知甚少.
前几天我重读了《An introduction to ROC analysis》终于将 ROC 曲线彻底搞清楚了,独乐乐不如众乐乐!如果你也对 ROC 的算法及实现 感兴趣,不妨花些时间看完全文,相信你一定会有所收获! 。
下图中的 蓝色曲线就是 ROC 曲线 ,它常被用来评价二值分类器的优劣,即评估模型预测的准确度.
二值分类器,就是字面意思它会将数据分成两个类别(正/负样本)。例如:预测银行用户是否会违约、内容分为违规和不违规,以及广告过滤、图片分类等场景。 篇幅关系这里不做多分类 ROC 的讲解.
TPR: True positive rate; FPR: False positive rate 。
坐标系中纵轴为 TPR(真阳率/命中率/召回率)最大值为 1,横轴为 FPR(假阳率/误判率)最大值为 1,虚线为基准线(最低标准),蓝色的曲线就是 ROC 曲线。其中 ROC 曲线距离基准线越远 ,则说明该模型的 预测效果越好 .
考虑一个二分类模型, 负样本(Negative) 为 0,正样本(Positive) 为 1。即:
因此,将 \(y\) 与 \(\hat{y}\) 两两组合就会得到 4 种可能性,分别称为:
ROC 曲线的横坐标为 FPR(False Positive Rate),纵坐标为 TPR(True Positive Rate)。FPR 统计了所有负样本中 预测错误(FP) 的比例,TPR 统计了所有正样本中 预测正确(TP) 的比例,其计算公式如下,其中 # 表示统计个数,例如 #N 表示负样本的个数,#P 表示正样本的个数 。
\(\text{FPR}=\frac{\#\text{FP}}{\#\text{N}}\) , \(\text{TPR}=\frac{\#\text{TP}}{\#\text{P}}\) 。
下面举一个实际例子作为讲解,以下表 5 个样本为例, 讲解如何计算 FPR 和 TPR .
id | 真实标签 \(y\) | 预测标签 \(\hat{y}\) |
---|---|---|
1 | 1 | 1 |
2 | 1 | 0 |
3 | 0 | 0 |
4 | 1 | 1 |
5 | 0 | 1 |
正样本数 #P=3,负样本数 #N=2.
其中 \(y=0\) 且 \(\hat{y}=1\) 的样本有 1 个,即 #FP=1,所以 FPR=1/2=0.5 。
其中 \(y=1\) 且 \(\hat{y}=1\) 的样本有 2 个,即 #TP=2,所以 FPR=2/3 。
FPR 和 TPR 的取值范围均是 0 到 1 之间。对于 FPR,我们希望其越小越好。而对于 TPR,我们希望其越大越好.
至此,我们已经介绍完如何计算 FPR 和 TPR 的值,下面将会讲解如何绘制 ROC 曲线.
讲到这里,可能有的同学会问:ROC 不是一条曲线吗?讲了这么多它到底应该怎么画呢?下面将分为两部分讲解如何绘制 ROC 曲线,直接打通你的“任督二脉”彻底拿下 ROC 曲线:
如果说上面是“开胃小菜”,那下面就是正菜啦! 。
一般在二分类模型里(标签取值为 0 或 1),会默认设定一个阈值 (threshold)。当预测分数大于这个阈值时,输出 1,反之输出 0。我们可以通过调节这个阈值,改变模型预测的输出,进而画出 ROC 曲线.
以下面表格中的 20 个点为例,介绍如何人工画出 ROC 曲线,其中正样本和负样本都是 10 个,即 #P = #N = 10.
id | 真实标签 | 预测分数 | id | 真实标签 | 预测分数 |
---|---|---|---|---|---|
1 | 1 | .9 | 11 | 1 | .4 |
2 | 1 | .8 | 12 | 0 | .39 |
3 | 0 | .7 | 13 | 1 | .38 |
4 | 1 | .6 | 14 | 0 | .37 |
5 | 1 | .55 | 15 | 0 | .36 |
6 | 1 | .54 | 16 | 0 | .35 |
7 | 0 | .53 | 17 | 1 | .34 |
8 | 0 | .52 | 18 | 0 | .33 |
9 | 1 | .51 | 19 | 1 | .30 |
10 | 0 | .505 | 20 | 0 | .1 |
当设定阈值为 0.9 时,只有第一个点预测为 1,其余都为 0,故 #FP=0、#TP=1,计算出 FPR=0/10=0,TPR=1/10=0.1,画出点 (0,0.1) 。
当设定阈值为 0.8 时,只有前两个点预测为 1,其余都为 0,故 #FP=0、#TP=2,计算出 FPR=0/10=0,TPR=2/10=0.2,画出点 (0,0.2) 。
当设定阈值为 0.7 时,只有前三个点预测为 1,其余都为 0,故 #FP=1、#TP=2,计算出 FPR=1/10=0.1,TPR=2/10=0.2,画出点 (0.1,0.2).
以此类推,画出的 ROC 曲线如下:
因此,在画 ROC 曲线前,需要将 预测分数从大到小排序 ,然后将 预测分数依次设定为阈值 ,分别计算 FPR 和 TPR。而对于基准线,假设随机预测为正样本的概率为 \(x\) ,即 \(\Pr(\hat{y}=1)=x\) 由于 FPR 计算的是负样本中,预测为正样本的概率,因此 FPR= \(x\) (同理,TPR= \(x\) )。所以, 基准线为从点 (0, 0) 到 (1, 1) 的斜线 .
接下来,我们将结合代码讲解如何在 Python 中绘制 ROC 曲线.
下面的代码参考了 《An Introduction to ROC Analysis》 中的算法 1(伪代码)。值得一提的是, 知名机器学习库 scikit-learn 的 roc_curve 函数 也参考了这个算法.
下面我自己实现的 roc 函数可以理解为是简化版的 roc_curve ,这里的代码逻辑更加简洁易懂,算法的时间复杂度 \(O(n\log n)\) 。完整的代码如下:
# import numpy as np
def roc(y_true, y_score, pos_label):
"""
y_true:真实标签
y_score:模型预测分数
pos_label:正样本标签,如“1”
"""
# 统计正样本和负样本的个数
num_positive_examples = (y_true == pos_label).sum()
num_negtive_examples = len(y_true) - num_positive_examples
tp, fp = 0, 0
tpr, fpr, thresholds = [], [], []
score = max(y_score) + 1
# 根据排序后的预测分数分别计算fpr和tpr
for i in np.flip(np.argsort(y_score)):
# 处理样本预测分数相同的情况
if y_score[i] != score:
fpr.append(fp / num_negtive_examples)
tpr.append(tp / num_positive_examples)
thresholds.append(score)
score = y_score[i]
if y_true[i] == pos_label:
tp += 1
else:
fp += 1
fpr.append(fp / num_negtive_examples)
tpr.append(tp / num_positive_examples)
thresholds.append(score)
return fpr, tpr, thresholds
导入上面 3.1 表格中的数据,通过上面实现的 roc 方法,计算 ROC 曲线的坐标值.
import numpy as np
y_true = np.array(
[1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0]
)
y_score = np.array([
.9, .8, .7, .6, .55, .54, .53, .52, .51, .505,
.4, .39, .38, .37, .36, .35, .34, .33, .3, .1
])
fpr, tpr, thresholds = roc(y_true, y_score, pos_label=1)
最后,通过 Matplotlib 将计算出的 ROC 曲线坐标绘制成图.
import matplotlib.pyplot as plt
plt.plot(fpr, tpr)
plt.axis("square")
plt.xlabel("False positive rate")
plt.ylabel("True positive rate")
plt.title("ROC curve")
plt.show()
至此,ROC 的基础知识部分就全部讲完了,如果还想深入了解的同学可以继续往下看.
如果将上面的内容比作“正餐”,那这里就是妥妥干货了,打起精神冲鸭! 。
顾名思义,ROC 平均就是将多条 ROC 曲线“平均化”。那么,什么场景需要做 ROC 平均呢?例如: 横向联邦学习中,由于样本都在用户本地,服务器可以采用 ROC 平均的方式,计算近似的全局 ROC 曲线 .
ROC 的平均有两种方法: 垂直平均、阈值平均 ,下面将逐一进行讲解,并给出 Python 代码实现.
垂直平均(Vertical averaging)的思想是,选取一些 FPR 的点,计算其平均的 TPR 值。下面是论文中的算法描述的伪代码,看不懂可直接略过看 Python 代码实现部分.
下面是 Python 的代码实现:
# import numpy as np
def roc_vertical_avg(samples, FPR, TPR):
"""
samples:选取FPR点的个数
FPR:包含所有FPR的列表
TPR:包含所有TPR的列表
"""
nrocs = len(FPR)
tpravg = []
fpr = [i / samples for i in range(samples + 1)]
for fpr_sample in fpr:
tprsum = 0
# 将所有计算的tpr累加
for i in range(nrocs):
tprsum += tpr_for_fpr(fpr_sample, FPR[i], TPR[i])
# 计算平均的tpr
tpravg.append(tprsum / nrocs)
return fpr, tpravg
# 计算对应fpr的tpr
def tpr_for_fpr(fpr_sample, fpr, tpr):
i = 0
while i < len(fpr) - 1 and fpr[i + 1] <= fpr_sample:
i += 1
if fpr[i] == fpr_sample:
return tpr[i]
else:
return interpolate(fpr[i], tpr[i], fpr[i + 1], tpr[i + 1], fpr_sample)
# 插值
def interpolate(fprp1, tprp1, fprp2, tprp2, x):
slope = (tprp2 - tprp1) / (fprp2 - fprp1)
return tprp1 + slope * (x - fprp1)
阈值平均(Threshold averaging)的思想是,选取一些阈值的点,计算其平均的 FPR 和 TPR.
下面是 Python 的代码实现:
# import numpy as np
def roc_threshold_avg(samples, FPR, TPR, THRESHOLDS):
"""
samples:选取FPR点的个数
FPR:包含所有FPR的列表
TPR:包含所有TPR的列表
THRESHOLDS:包含所有THRESHOLDS的列表
"""
nrocs = len(FPR)
T = []
fpravg = []
tpravg = []
for thresholds in THRESHOLDS:
for t in thresholds:
T.append(t)
T.sort(reverse=True)
for tidx in range(0, len(T), int(len(T) / samples)):
fprsum = 0
tprsum = 0
# 将所有计算的fpr和tpr累加
for i in range(nrocs):
fprp, tprp = roc_point_at_threshold(FPR[i], TPR[i], THRESHOLDS[i], T[tidx])
fprsum += fprp
tprsum += tprp
# 计算平均的fpr和tpr
fpravg.append(fprsum / nrocs)
tpravg.append(tprsum / nrocs)
return fpravg, tpravg
# 计算对应threshold的fpr和tpr
def roc_point_at_threshold(fpr, tpr, thresholds, thresh):
i = 0
while i < len(fpr) - 1 and thresholds[i] > thresh:
i += 1
return fpr[i], tpr[i]
在我们的 PrimiHub 联邦学习模块 中,就实现了上述 ROC 平均方法.
本文由浅入深地详细介绍了 ROC 曲线算法,包含算法原理、公式、计算、源码实现和讲解,希望能够帮助读者一口气(看的时候可得喘气 😮💨)搞懂 ROC.
虽然 ROC 是个不起眼的知识点,但能网上能彻底讲清楚 ROC 的文章并不多。所以我又花时间重温了一遍 Tom Fawcett 的经典论文 《An introduction to ROC analysis》 ,并将论文的内容抽丝剥茧、配上通俗易懂的 Python 代码,最终写出了这篇文章。再次致敬🫡 Tom Fawcett,感谢他在机器学习领域的贡献! 。
我们是 PrimiHub 密码学专家团队,用心去写每一篇内容,让每一位点开文章的读者都能有所收获。我们的内容专注于隐私计算领域,偶尔也涉及下机器学习领域。 如果大家喜欢这个系列请留言告诉我们,它的姐妹篇 ACU 详解直接安排! 。
PrimiHub 一款由密码学专家团队打造的开源隐私计算平台,专注于分享数据安全、密码学、联邦学习、同态加密等隐私计算领域的技术和内容.
最后此篇关于小白也能看懂的ROC曲线详解的文章就讲到这里了,如果你想了解更多关于小白也能看懂的ROC曲线详解的内容请搜索CFSDN的文章或继续浏览相关文章,希望大家以后支持我的博客! 。
全称“Java Virtual Machine statistics monitoring tool”(statistics 统计;monitoring 监控;tool 工具) 用于监控虚拟机的各种运
主要是讲下Mongodb的索引的查看、创建、删除、类型说明,还有就是Explain执行计划的解释说明。 可以转载,但请注明出处。  
1>单线程或者单进程 相当于短链接,当accept之后,就开始数据的接收和数据的发送,不接受新的连接,即一个server,一个client 不存在并发。 2>循环服务器和并发服务器
详解 linux中的关机和重启命令 一 shutdown命令 shutdown [选项] 时间 选项: ?
首先,将json串转为一个JObject对象: ? 1
matplotlib官网 matplotlib库默认英文字体 添加黑体(‘SimHei')为绘图字体 代码: plt.rcParams['font.sans-serif']=['SimHei'
在并发编程中,synchronized关键字是常出现的角色。之前我们都称呼synchronized关键字为重量锁,但是在jdk1.6中对synchronized进行了优化,引入了偏向锁、轻量锁。本篇
一般我们的项目中会使用1到2个数据库连接配置,同程艺龙的数据库连接配置被收拢到统一的配置中心,由DBA统一配置和维护,业务方通过某个字符串配置拿到的是Connection对象。  
实例如下: ? 1
1. MemoryCahe NetCore中的缓存和System.Runtime.Caching很相似,但是在功能上做了增强,缓存的key支持object类型;提供了泛型支持;可以读缓存和单个缓存
argument是javascript中函数的一个特殊参数,例如下文,利用argument访问函数参数,判断函数是否执行 复制代码 代码如下: <script
一不小心装了一个Redis服务,开了一个全网的默认端口,一开始以为这台服务器没有公网ip,结果发现之后悔之莫及啊 某天发现cpu load高的出奇,发现一个minerd进程 占了大量cpu,googl
今天写这个是为了 提醒自己 编程过程 不仅要有逻辑 思想 还有要规范 代码 这样可读性 1、PHP 编程规范与编码习惯最主要的有以下几点: 1 文件说明 2 funct
摘要:虚拟机安装时一般都采用最小化安装,默认没有lspci工具。一台测试虚拟网卡性能的虚拟机,需要lspci工具来查看网卡的类型。本文描述了在一个虚拟机中安装lspci工具的具体步骤。 由于要测试
1、修改用户进程可打开文件数限制 在Linux平台上,无论编写客户端程序还是服务端程序,在进行高并发TCP连接处理时,最高的并发数量都要受到系统对用户单一进程同时可打开文件数量的限制(这是因为系统
目录 算术运算符 基本四则运算符 增量赋值运算符 自增/自减运算符 关系运算符 逻
如下所示: ? 1
MapperScannerConfigurer之sqlSessionFactory注入方式讲解 首先,Mybatis中的有一段配置非常方便,省去我们去写DaoImpl(Dao层实现类)的时间,这个
Linux的网络虚拟化是LXC项目中的一个子项目,LXC包括文件系统虚拟化,进程空间虚拟化,用户虚拟化,网络虚拟化,等等,这里使用LXC的网络虚拟化来模拟多个网络环境。 本文从基本的网络设备讲
? 1
我是一名优秀的程序员,十分优秀!