gpt4 book ai didi

arrays - 多维np.argmax?

转载 作者:行者123 更新时间:2023-12-03 18:27:44 25 4
gpt4 key购买 nike

我有一个形状为 (n, n, g) 的 3D 数组,我需要每个 (n, n) argmax,即结果应该是两个长度为 g 的索引向量 (x, y)。

直观的解决方案是:

array = np.random.uniform(size=[5, 5, 1000])
np.argmax(array, axis=[0, 1])

但是,numpy 不支持多个轴作为参数。

有没有办法得到这个结果?

最佳答案

您似乎希望为 2D 的每个展平版本获得二维(行、列)argmax 索引。切片与前两个轴合并/组合说话。所以,第一个区块是 array[:,:,0]依此类推,我们需要找到 argmax,该切片被展平并返回到原始 2D 形状。因此,为了解决这个问题,我们可以简单地 reshape 以合并前两个轴,沿着第一个轴获得 argmax,这是 reshape 后的合并轴,并使用 np.unravel_index 回溯原始索引。 ,像这样——

array2D = array.reshape(-1,array.shape[-1])
r,c = np.unravel_index(array2D.argmax(0),array.shape[:2])

sample 运行 -
In [29]: array = np.random.uniform(size=[5, 5, 1000])

In [30]: array2D = array.reshape(-1,array.shape[-1])

In [31]: r,c = np.unravel_index(array2D.argmax(0),array.shape[:2])

In [32]: len(r), len(c)
Out[32]: (1000, 1000)

让我们验证第一个 2D 的结果切片——
In [33]: array[:,:,0]
Out[33]:
array([[0.81590174, 0.17919069, 0.22717883, 0.67863625, 0.97390595],
[0.82096447, 0.05894774, 0.86379174, 0.13494354, 0.10003756],
[0.37243189, 0.33714008, 0.21165031, 0.35910642, 0.15163255],
[0.1376776 , 0.86866599, 0.43602004, 0.85421372, 0.77805012],
[0.10519547, 0.7422571 , 0.35632275, 0.24168307, 0.76882613]])

In [34]: array[:,:,0].argmax()
Out[34]: 4 # flattened index for 0.97390595 at (0,4) in the 2D slice

In [36]: r[0],c[0]
Out[36]: (0, 4)

关于arrays - 多维np.argmax?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/51868834/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com