gpt4 book ai didi

python - 有没有办法获取 numpy 数组(Python)每行的前 k 个值?

转载 作者:行者123 更新时间:2023-12-02 23:03:47 24 4
gpt4 key购买 nike

给定一个如下形式的 numpy 数组:

x = [[4.,3.,2.,1.,8.],[1.2,3.1,0.,9.2,5.5],[0.2,7.0,4.4,0.2,1.3]]

有没有办法在 python 中保留每行中的前 3 个值并将其他值设置为零(无需显式循环)。上面示例的结果将是

x = [[4.,3.,0.,0.,8.],[0.,3.1,0.,9.2,5.5],[0.0,7.0,4.4,0.0,1.3]]

示例代码

import numpy as np
arr = np.array([1.2,3.1,0.,9.2,5.5,3.2])
indexes=arr.argsort()[-3:][::-1]
a = list(range(6))
A=set(indexes); B=set(a)
zero_ind=(B.difference(A))
arr[list(zero_ind)]=0

输出:

array([0. , 0. , 0. , 9.2, 5.5, 3.2])

上面是我的一维 numpy 数组的示例代码(有很多行)。循环遍历 numpy 数组的每一行并重复执行相同的计算将非常昂贵。有没有更简单的方法?

最佳答案

这是一个完全矢量化的代码,没有 numpy 之外的第三方。它使用 numpy 的 argpartition有效地找到第 k 个值。例如,参见this answer用于其他用例。

def truncate_top_k(x, k, inplace=False):
m, n = x.shape
# get (unsorted) indices of top-k values
topk_indices = numpy.argpartition(x, -k, axis=1)[:, -k:]
# get k-th value
rows, _ = numpy.indices((m, k))
kth_vals = x[rows, topk_indices].min(axis=1)
# get boolean mask of values smaller than k-th
is_smaller_than_kth = x < kth_vals[:, None]
# replace mask by 0
if not inplace:
return numpy.where(is_smaller_than_kth, 0, x)
x[is_smaller_than_kth] = 0
return x

关于python - 有没有办法获取 numpy 数组(Python)每行的前 k 个值?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/59402551/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com