gpt4 book ai didi

python - 选择 NumPy 数组中每列中出现的所有前 K 值

转载 作者:太空宇宙 更新时间:2023-11-03 16:26:35 25 4
gpt4 key购买 nike

假设我有一个 NumPy 数组,如下所示:我的原始数组大小为 50K X8.5K。这是示例

array([[ 1. ,  2. ,  3. ],
[ 1. , 0.5, 2. ],
[ 2. , 3. , 1. ]])

现在我想要的是,对于每一列,仅保留前 K 个值(此处将 K 设为 2)并将其他值重新编码为零。

所以我期望的输出是这样的:

array([[ 1.,  2.,  3.],
[ 1., 0., 2.],
[ 2., 3., 0.]])

所以基本上,如果我们看到,我们会按降序对每个列值进行排序,然后检查该列的每个值是否不在该列的 k 个最大值中,然后将该值重新编码为零

我尝试了类似的方法,但出现错误

for x in range(e.shape[1]):
e[:,x]=map(np.where(lambda x: x in e[:,x][::-1][:2], x, 0), e[:,x])



2
3 for x in range(e.shape[1]):
----> 4 e[:,x]=map(np.where(lambda x: x in e[:,x][::-1][:2], x, 0), e[:,x])
5

TypeError: 'numpy.ndarray' object is not callable

目前我也在对每一列进行迭代。任何快速工作的解决方案,因为我有 50K 行和 8K 列,因此对每一列进行迭代,然后对每一列进行该列中每个值的映射,我想这将非常耗时。

请指教。

最佳答案

针对如此大型数组的性能,这里有一个矢量化方法来解决它 -

K = 2 # Select top K values along each column

# Sort A, store the argsort for later usage
sidx = np.argsort(A,axis=0)
sA = A[sidx,np.arange(A.shape[1])]

# Perform differentiation along rows and look for non-zero differentiations
df = np.diff(sA,axis=0)!=0

# Perform cumulative summation along rows from bottom upwards.
# Thus, summations < K should give us a mask of valid ones that are to
# be kept per column. Use this mask to set rest as zeros in sorted array.
mask = (df[::-1].cumsum(0)<K)[::-1]
sA[:-1] *=mask

# Finally revert back to unsorted order by using sorted indices sidx
out = sA[sidx.argsort(0),np.arange(sA.shape[1])]

请注意,为了获得更多性能提升,np.argsort 可以替换为 np.argpartition

示例输入、输出 -

In [343]: A
Out[343]:
array([[106, 106, 102],
[105, 101, 104],
[106, 107, 101],
[107, 103, 106],
[106, 105, 108],
[106, 104, 105],
[107, 101, 101],
[105, 103, 102],
[104, 102, 106],
[104, 106, 101]])

In [344]: out
Out[344]:
array([[106, 106, 0],
[ 0, 0, 0],
[106, 107, 0],
[107, 0, 106],
[106, 0, 108],
[106, 0, 0],
[107, 0, 0],
[ 0, 0, 0],
[ 0, 0, 106],
[ 0, 106, 0]])

关于python - 选择 NumPy 数组中每列中出现的所有前 K 值,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/37930431/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com