gpt4 book ai didi

python - 在numpy的混淆矩阵中获取与每种错误类型实例对应的索引

转载 作者:太空宇宙 更新时间:2023-11-03 13:18:42 25 4
gpt4 key购买 nike

我不只是获得一个混淆矩阵,还希望能够获得犯下特定类型错误的实例的索引(或数组掩码)。因此,例如,我想查看在属于 class-0 时被预测为 class-2 的实例,等等。

我可以使用列表理解轻松获得数组掩码:

import numpy as np

y_true, y_pred = np.array([0, 1, 0, 2, 1, 1]), np.array([0, 0, 0, 2, 1, 2])
np.array([[np.logical_and(y_true==r, y_pred==c) for c in xrange(3)] for r in xrange(3)])

这会产生:

[[[ True False  True False False False]
[False False False False False False]
[False False False False False False]]

[[False True False False False False]
[False False False False True False]
[False False False False False True]]

[[False False False False False False]
[False False False False False False]
[False False False True False False]]]

(要获取索引,我可以使用 np.where())。上面对应的是混淆矩阵:

[[2 0 0]
[1 1 1]
[0 0 1]]

但是,我想问一下是否有 numpy-thonic 单行代码来帮助我消除嵌套列表理解?

最佳答案

要将这些花哨的令人困惑的索引解决方案之一添加到组合中,您还可以:

>>> y_true = np.array([0, 1, 0, 2, 1, 1])
>>> y_pred = np.array([0, 0, 0, 2, 1, 2])
>>> out = np.zeros((3, 3, len(y_true)), dtype=np.bool)
>>> out[y_true, y_pred, np.arange(len(y_true))] = True
>>> out
array([[[ True, False, True, False, False, False],
[False, False, False, False, False, False],
[False, False, False, False, False, False]],

[[False, True, False, False, False, False],
[False, False, False, False, True, False],
[False, False, False, False, False, True]],

[[False, False, False, False, False, False],
[False, False, False, False, False, False],
[False, False, False, True, False, False]]], dtype=bool)

您可以获得在最后一个轴上对上述矩阵求和的混淆矩阵,但如果这就是您所追求的,最好直接使用 np.bincount 构建它:

>>> np.bincount(y_pred + 3*y_true, minlength=9).reshape(3,3)
array([[2, 0, 0],
[1, 1, 1],
[0, 0, 1]], dtype=int64)

SciPy 的 sparse_coo 矩阵将重复索引相加,因此以下内容也有效:

>>> sps.coo_matrix((np.ones_like(y_true, dtype=np.intp),
--- (y_true, y_pred)), shape=(3, 3)).A
array([[2, 0, 0],
[1, 1, 1],
[0, 0, 1]], dtype=int64)

关于python - 在numpy的混淆矩阵中获取与每种错误类型实例对应的索引,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/21153865/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com