gpt4 book ai didi

python - 如何查找ndarray中的所有argmax

转载 作者:行者123 更新时间:2023-12-01 03:31:38 25 4
gpt4 key购买 nike

我有一个二维 NumPy ndarray。

array([[  0.,  20.,  -2.],
[ 2., 1., 0.],
[ 4., 3., 20.]])

如何获取最大元素的所有索引?所以我想作为输出数组([0,1],[2,2])。

最佳答案

使用np.argwhere关于最大平等掩码 -

np.argwhere(a == a.max())

示例运行 -

In [552]: a   # Input array
Out[552]:
array([[ 0., 20., -2.],
[ 2., 1., 0.],
[ 4., 3., 20.]])

In [553]: a == a.max() # Max equality mask
Out[553]:
array([[False, True, False],
[False, False, False],
[False, False, True]], dtype=bool)

In [554]: np.argwhere(a == a.max()) # array of row, col indices of max-mask
Out[554]:
array([[0, 1],
[2, 2]])

如果您使用 float ,您可能需要使用一些容差。因此,考虑到这一点,您可以使用 np.isclose它有一些默认的绝对和相对容差值。这将替换之前的 a == a.max() 部分,就像这样 -

In [555]: np.isclose(a, a.max())
Out[555]:
array([[False, True, False],
[False, False, False],
[False, False, True]], dtype=bool)

关于python - 如何查找ndarray中的所有argmax,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/40897809/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com