gpt4 book ai didi

python - 获取 numpy 数组中 N 个最高值的索引

转载 作者:太空宇宙 更新时间:2023-11-03 14:29:36 28 4
gpt4 key购买 nike

我的代码:

import numpy as np
N = 2
a = np.array([[0.5, 0.3, 0.2],
[0.2, 0.6, 0.2],
[0.3, 0.2, 0.7],
[np.nan, 0.2, 0.8],
[np.nan, np.nan, 0.8]
])

ind = np.argsort(np.where(np.isnan(a), -1, a), axis=1)[:, -N:]


a
Out[2]:
array([[ 0.5, 0.3, 0.2],
[ 0.2, 0.6, 0.2],
[ 0.3, 0.2, 0.7],
[ nan, 0.2, 0.8],
[ nan, nan, 0.8]])

ind
Out[3]:
array([[1, 0],
[2, 1],
[0, 2],
[1, 2],
[1, 2]], dtype=int64)

ind[:,1] 最高,ind[:,0] 第二高

除了最后一行有 2 个 nan 的情况之外,其他都很好。如果它是 nan 如何忽略第二高值?期望的输出是:

array([[1, 0],
[2, 1],
[0, 2],
[1, 2],
[nan, 2]], dtype=int64)

额外问题:在出现 a[1,:] 的情况下如何随机打破平局?

最佳答案

Advanced-index并检查 NaNs 为我们提供一个掩码,然后可以将其与 np.where 一起使用来进行选择,就像这样 -

In [244]: a_ind = a[np.arange(ind.shape[0])[:,None],ind]

In [245]: mask = np.isnan(a_ind)

In [246]: np.where(mask, np.nan, ind)
Out[246]:
array([[ 1., 0.],
[ 2., 1.],
[ 0., 2.],
[ 1., 2.],
[ nan, 2.]])

请注意,具有 NaN 的数组将被转换为 float dtype,因此最终输出也将是 float dtype。

关于python - 获取 numpy 数组中 N 个最高值的索引,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/47393115/

28 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com