gpt4 book ai didi

python - 在 numpy 中对分区索引分组 argmax/argmin

转载 作者:太空宇宙 更新时间:2023-11-03 11:52:44 28 4
gpt4 key购买 nike

Numpy 的 ufunc我们有一个 reduceat 在数组中的连续分区上运行它们的方法。所以不要写:

import numpy as np
a = np.array([4, 0, 6, 8, 0, 9, 8, 5, 4, 9])
split_at = [4, 5]
maxima = [max(subarray for subarray in np.split(a, split_at)]

我会写:

maxima = np.maximum.reduceat(a, np.hstack([0, split_at]))

两者都将返回切片中的最大值 a[0:4] , a[4:5] , a[5:10] , 是 [8, 0, 9] .

我想要一个类似的函数来执行 argmax ,注意到我只想在每个分区中有一个单个最大索引:[3, 4, 5]与上述 asplit_at (尽管索引 5 和 9 都获得了最后一组中的最大值),正如

np.hstack([0, split_at]) + [np.argmax(subarray) for subarray in np.split(a, split_at)]

我将在下面发布一个可能的解决方案,但我希望看到一个无需在组上创建索引即可矢量化的解决方案。

最佳答案

此解决方案涉及在组(上例中的 [0, 0, 0, 0, 1, 2, 2, 2, 2, 2])上构建索引。

group_lengths = np.diff(np.hstack([0, split_at, len(a)]))
n_groups = len(group_lengths)
index = np.repeat(np.arange(n_groups), group_lengths)

然后我们可以使用:

maxima = np.maximum.reduceat(a, np.hstack([0, split_at]))
all_argmax = np.flatnonzero(np.repeat(maxima, group_lengths) == a)
result = np.empty(len(group_lengths), dtype='i')
result[index[all_argmax[::-1]]] = all_argmax[::-1]

result 中得到 [3, 4, 5][::-1] 确保我们得到每个组中的 第一个 而不是最后一个 argmax。

这依赖于花式赋值中的最后一个索引决定赋值的事实,@seberg says one shouldn't rely on (使用 result = all_argmax[np.unique(index[all_argmax], return_index=True)[1]] 可以实现更安全的替代方案,它涉及对 len(maxima) 的排序~ n_groups 个元素)。

关于python - 在 numpy 中对分区索引分组 argmax/argmin,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/22124332/

28 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com