gpt4 book ai didi

python - 将预测转换为具有阈值的标签

转载 作者:太空宇宙 更新时间:2023-11-04 02:08:45 27 4
gpt4 key购买 nike

假设我有以下 numpy 数组

import numpy as np

arr = np.array([[0.2, 0.8], [0.99, 0.01], [0.08, 0.92]])
arr
Out[57]:
array([[0.2 , 0.8 ],
[0.99, 0.01],
[0.08, 0.92]])

如果我想将此输出转换为“类”(或每行中最大值的索引),我只需使用:

arr.argmax(axis=1)
Out[58]: array([1, 0, 1], dtype=int64)

问题是,我想限制某个阈值。对于示例,我们使用 0.9。因此,不满足阈值约束的每一行都将返回标签 -1。

上面示例的输出将是 [-1, 0, 1](因为 0.8 和 0.2 都不大于 0.9)。

最符合 Python 风格的方法是什么?希望(但不是必须)使用 numpy

最佳答案

您可以使用 np.where :

m = arr > 0.9
np.where(m.any(axis=1), m.argmax(axis=1), -1)
array([-1, 0, 1])

详情

(arr > 0.9) 返回一个具有相同形状的 ndarray,指示满足条件的位置:

array([[False, False],
[ True, False],
[False, True]])

m.argmax(axis=1) 返回 mTrue 的地方:

array([0, 0, 1])

np.where 将为那些满足 m.any(axis=1) 的行返回 m.argmax(axis=1) ,因此至少有一个元素大于阈值。这里 m.any(axis=1) 给出:

array([False,  True,  True])

否则np.where会返回-1

关于python - 将预测转换为具有阈值的标签,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/54093719/

27 4 0