gpt4 book ai didi

python - 3-d numpy 数组的 2 轴上的 argmax

转载 作者:太空宇宙 更新时间:2023-11-04 05:14:57 25 4
gpt4 key购买 nike

我想从 3D 矩阵中获取一维索引数组。

例如给定 x = np.random.randint(10, size=(10,3,3)),我想做类似 np.argmax(x , axis=(1,2)) 就像你可以用 np.max 做的那样,即获得一个长度为 10 的一维数组,其中包含的索引(0 到 8)大小为 (3,3) 的每个子矩阵的最大值。

到目前为止我还没有发现任何有用的东西,我想避免在第一个维度上循环(并使用 np.argmax(x)),因为它很大。

干杯!

最佳答案

reshape 以合并最后两个轴,然后使用 np.argmax -

idx = x.reshape(x.shape[0],-1).argmax(-1)
out = np.unravel_index(idx, x.shape[-2:])

sample 运行-

In [263]: x = np.random.randint(10, size=(4,3,3))

In [264]: x
Out[264]:
array([[[0, 9, 2],
[7, 7, 8],
[2, 5, 9]],

[[1, 7, 2],
[8, 9, 0],
[2, 8, 3]],

[[7, 5, 0],
[7, 1, 6],
[5, 1, 1]],

[[0, 7, 3],
[5, 4, 1],
[9, 8, 9]]])

In [265]: idx = x.reshape(x.shape[0],-1).argmax(-1)

In [266]: np.unravel_index(idx, x.shape[-2:])
Out[266]: (array([0, 1, 0, 2]), array([1, 1, 0, 0]))

如果你的意思是获取合并索引,那么它更简单 -

x.reshape(x.shape[0],-1).argmax(1)

sample 运行-

In [283]: x
Out[283]:
array([[[2, 3, 7],
[8, 1, 0],
[3, 6, 9]],

[[8, 0, 5],
[2, 2, 9],
[9, 0, 9]],

[[1, 9, 2],
[5, 0, 3],
[7, 2, 1]],

[[1, 6, 5],
[2, 3, 7],
[7, 4, 6]]])

In [284]: x.reshape(x.shape[0],-1).argmax(1)
Out[284]: array([8, 5, 1, 5])

关于python - 3-d numpy 数组的 2 轴上的 argmax,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/41990871/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com