gpt4 book ai didi

python - 如何在numpy数组的给定行中保留N个最小元素?

转载 作者:行者123 更新时间:2023-12-01 12:02:15 26 4
gpt4 key购买 nike

给定一个二维 numpy 矩阵,如何保留每行中的 N 个最小元素并将其余元素更改为 0 (零)。

例如:N=3输入数组:

1   2   3   4   5
4 3 6 1 0
6 5 3 1 2

预期输出:
1   2   3   0   0
0 3 0 1 0
0 0 3 1 2

以下是我尝试过的代码,它可以工作:
# distance_matrix is the given 2D array
N=3
for i in range(distance_matrix.shape[0]):
n_th_largest = np.sort(distance_matrix[i])[N]
for j in range(distance_matrix.shape[1]):
distance_matrix[i][j] = np.where(distance_matrix[i][j]<n_th_largest,distance_matrix[i][j],0)

# return distance_matrix

但是,此操作涉及迭代每个元素。有没有更快的方法来使用 np.argsort() 来解决这个问题?或任何其他功能?

最佳答案

方法#1

这是一个 np.argpartition 性能效率 -

N = 3
newval = 0
np.put_along_axis(a,np.argpartition(a,N,axis=1)[:,N:],newval,axis=1)

说明:我们对输入数组进行分区以获得为 kth 分区的索引。参数在 np.argpartition .因此,基本上将其视为两个分区,第一个分区用于沿该轴的最小 N 个元素,另一个用于其余部分。我们需要重置第二个分区,我们选择 [:,N:]我们使用 np.put_along_axis 进行重置。

sample 运行 -
In [144]: a # input array
Out[144]:
array([[1, 2, 3, 4, 5],
[4, 3, 6, 1, 0],
[6, 5, 3, 1, 2]])

In [145]: np.put_along_axis(a,np.argpartition(a,3,axis=1)[:,3:],0,axis=1)

In [146]: a
Out[146]:
array([[1, 2, 3, 0, 0],
[0, 3, 0, 1, 0],
[0, 0, 3, 1, 2]])

方法#2

这是另一个 np.argpartition ,但只是对每行第 N 个最小元素进行切片,然后重置所有大于它的元素。因此,如果第 N 个最小元素有重复项,我们将使用此方法保留所有重复项。这是实现 -
a[a>=a[np.arange(len(a)), np.argpartition(a,3,axis=1)[:,3],None]] = 0

放大版本的时间 -
In [184]: a = np.array([[1,2,3,4,5],[4,3,6,1,0],[6,5,3,1,2]])

In [185]: a = np.repeat(a,10000,axis=0)

In [186]: %timeit np.put_along_axis(a,np.argpartition(a,3,axis=1)[:,3:],0,axis=1)
1.78 ms ± 5.89 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

In [187]: a = np.array([[1,2,3,4,5],[4,3,6,1,0],[6,5,3,1,2]])

In [188]: a = np.repeat(a,10000,axis=0)

In [189]: %timeit a[a>=a[np.arange(len(a)), np.argpartition(a,3,axis=1)[:,3],None]] = 0
1.54 ms ± 54.7 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

关于python - 如何在numpy数组的给定行中保留N个最小元素?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/60871113/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com