gpt4 book ai didi

python - 在 Numpy 数组中每行保留最大 N 个值

转载 作者:行者123 更新时间:2023-12-05 04:47:55 27 4
gpt4 key购买 nike

我需要在数组中每行保留最多 N (3) 个值。

a=np.array([[1,2,3,4],[8,7,6,5],[5,3,1,2]])
a
Out[135]:
array([[1, 2, 3, 4],
[8, 7, 6, 5],
[5, 3, 1, 2]])

它们的索引可以用np.partition来识别:

n=3
np.argpartition(a, -n, axis=1)[:,-n:]
Out[136]:
array([[1, 2, 3],
[2, 1, 0],
[3, 0, 1]], dtype=int64)

所以,我的问题是:我应该如何保留这些索引的值并将其他索引设置为零以获得:

Out[136]: 
array([[0, 2, 3, 4],
[8, 7, 6, 0],
[5, 3, 0, 2]])

最佳答案

a=np.array([[1,2,3,4],[8,7,6,5],[5,3,1,2]])

n=3
mask = np.argpartition(a, -n, axis=1) < a.shape[1] - n

a[mask] = 0

关于python - 在 Numpy 数组中每行保留最大 N 个值,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/68290728/

27 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com