gpt4 book ai didi

Python:沿大于阈值的特定维度查找最大数组索引

转载 作者:行者123 更新时间:2023-11-28 21:43:56 24 4
gpt4 key购买 nike

假设我有一个 4 维 numpy 数组(例如:np.rand((x,y,z,t))),其数据维度对应于 X、Y、Z、和时间。

对于每个 X 和 Y 点,在每个时间步,我想找到 Z 中数据大于某个阈值 n 的最大索引。

所以我的最终结果应该是一个 X-by-Y-by-t 数组。 Z 列中没有大于阈值的值的实例应由 0 表示。

我可以逐个元素地循环并构建一个新数组,但是我在一个非常大的数组上操作并且需要很长时间。

最佳答案

不幸的是,按照 Python 内置函数的示例,numpy 并不容易获得 last 索引,尽管 first 是微不足道的。仍然,像

def slow(arr, axis, threshold):
return (arr > threshold).cumsum(axis=axis).argmax(axis=axis)

def fast(arr, axis, threshold):
compare = (arr > threshold)
reordered = compare.swapaxes(axis, -1)
flipped = reordered[..., ::-1]
first_above = flipped.argmax(axis=-1)
last_above = flipped.shape[-1] - first_above - 1
are_any_above = compare.any(axis=axis)
# patch the no-matching-element found values
patched = np.where(are_any_above, last_above, 0)
return patched

给我

In [14]: arr = np.random.random((100,100,30,50))

In [15]: %timeit a = slow(arr, axis=2, threshold=0.75)
1 loop, best of 3: 248 ms per loop

In [16]: %timeit b = fast(arr, axis=2, threshold=0.75)
10 loops, best of 3: 50.9 ms per loop

In [17]: (slow(arr, axis=2, threshold=0.75) == fast(arr, axis=2, threshold=0.75)).all()
Out[17]: True

(可能有一种更灵活的方式来进行翻转,但现在已经结束了,我的大脑正在关闭。:-)

关于Python:沿大于阈值的特定维度查找最大数组索引,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/41515201/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com