gpt4 book ai didi

python - Numpy 匹配索引维度

转载 作者:太空狗 更新时间:2023-10-30 00:59:06 32 4
gpt4 key购买 nike

问题

我有两个 numpy 数组,Aindices .

A尺寸为 m x n x 10000。 indices尺寸为 m x n x 5(来自 argpartition(A, 5)[:,:,:5] 的输出)。我想要一个包含 A 元素的 m x n x 5 数组对应indices .

尝试

indices = np.array([[[5,4,3,2,1],[1,1,1,1,1],[1,1,1,1,1]],
[500,400,300,200,100],[100,100,100,100,100],[100,100,100,100,100]])
A = np.reshape(range(2 * 3 * 10000), (2,3,10000))

A[...,indices] # gives an array of size (2,3,2,3,5). I want a subset of these values
np.take(A, indices) # shape is right, but it flattens the array first
np.choose(indices, A) # fails because of shape mismatch.

动机

我正在尝试获取 A[i,j] 的 5 个最大值对于每个 i<m , j<n使用 np.argpartition 排序因为数组会变得相当大。

最佳答案

您可以使用 advanced-indexing -

m,n = A.shape[:2]
out = A[np.arange(m)[:,None,None],np.arange(n)[:,None],indices]

sample 运行-

In [330]: A
Out[330]:
array([[[38, 21, 61, 74, 35, 29, 44, 46, 43, 38],
[22, 44, 89, 48, 97, 75, 50, 16, 28, 78],
[72, 90, 48, 88, 64, 30, 62, 89, 46, 20]],

[[81, 57, 18, 71, 43, 40, 57, 14, 89, 15],
[93, 47, 17, 24, 22, 87, 34, 29, 66, 20],
[95, 27, 76, 85, 52, 89, 69, 92, 14, 13]]])

In [331]: indices
Out[331]:
array([[[7, 8, 1],
[7, 4, 7],
[4, 8, 4]],

[[0, 7, 4],
[5, 3, 1],
[1, 4, 0]]])

In [332]: m,n = A.shape[:2]

In [333]: A[np.arange(m)[:,None,None],np.arange(n)[:,None],indices]
Out[333]:
array([[[46, 43, 21],
[16, 97, 16],
[64, 46, 64]],

[[81, 14, 43],
[87, 24, 47],
[27, 52, 95]]])

为了获得对应于最后一个轴上最多 5 个元素的索引,我们将使用 argpartition,就像这样 -

indices = np.argpartition(-A,5,axis=-1)[...,:5]

要保持从高到低的顺序,请使用 range(5) 而不是 5

关于python - Numpy 匹配索引维度,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/45310816/

32 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com