gpt4 book ai didi

python - 替代 numpy.argwhere 以加速 python 中的循环

转载 作者:行者123 更新时间:2023-11-28 21:03:04 24 4
gpt4 key购买 nike

我有两个数据集如下:

ds1:作为 2d numpy 数组的 DEM(数字高程模型)文件,

ds2:显示其中有一些多余水分的区域(像素)。

我有一个 while 循环,它负责根据其 8 个相邻像素及其自身的高度散布(和更改)每个像素中的多余体积,直到每个像素中的多余体积小于某个值 d = 0.05。因此,在每次迭代中,我需要在 ds2 中找到超出体积大于 0.05 的像素索引,如果没有剩余像素,则退出 while 循环:

exit_code == "No"
while exit_code == "No":
index_of_pixels_with_excess_volume = numpy.argwhere(ds2> 0.05) # find location of pixels where excess volume is greater than 0.05

if not index_of_pixels_with_excess_volume.size:
exit_code = "Yes"
else:
for pixel in index_of_pixels_with_excess_volume:
# spread those excess volumes to the neighbours and
# change the values of ds2

问题是 numpy.argwhere(ds2> 0.05) 非常慢。我正在寻找更快的替代解决方案。

最佳答案

制作一个二维数组示例:

In [584]: arr = np.random.rand(1000,1000)

找出其中的一小部分:

In [587]: np.where(arr>.999)
Out[587]:
(array([ 1, 1, 1, ..., 997, 999, 999], dtype=int32),
array([273, 471, 584, ..., 745, 310, 679], dtype=int32))
In [588]: _[0].shape
Out[588]: (1034,)

时间 argwhere 的各个部分:

In [589]: timeit arr>.999
2.65 ms ± 116 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
In [590]: timeit np.count_nonzero(arr>.999)
2.79 ms ± 26 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
In [591]: timeit np.nonzero(arr>.999)
6 ms ± 10 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
In [592]: timeit np.argwhere(arr>.999)
6.06 ms ± 58.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

所以大约 1/3 的时间花在了 测试上,剩下的时间花在寻找 True 元素上。将 where 元组转换为 2 列数组的速度很快。

现在如果目标只是找到第一个 > 值,argmax 很快。

In [593]: np.argmax(arr>.999)
Out[593]: 1273 # can unravel this to (1,273)
In [594]: timeit np.argmax(arr>.999)
2.76 ms ± 143 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

argmax 短路,因此实际运行时间会随着找到第一个值而变化。

flatnonzerowhere 更快:

In [595]: np.flatnonzero(arr>.999)
Out[595]: array([ 1273, 1471, 1584, ..., 997745, 999310, 999679], dtype=int32)
In [596]: timeit np.flatnonzero(arr>.999)
3.05 ms ± 26.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
In [599]: np.unravel_index(np.flatnonzero(arr>.999),arr.shape)
Out[599]:
(array([ 1, 1, 1, ..., 997, 999, 999], dtype=int32),
array([273, 471, 584, ..., 745, 310, 679], dtype=int32))
In [600]: timeit np.unravel_index(np.flatnonzero(arr>.999),arr.shape)
3.05 ms ± 3.58 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
In [601]: timeit np.transpose(np.unravel_index(np.flatnonzero(arr>.999),arr.shap
...: e))
3.1 ms ± 5.86 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

这与 np.argwhere(arr>.999) 相同。

有趣的是,flatnonzero 方法将时间缩短了一半!没想到会有这么大的进步。


比较迭代速度:

argwhere 对二维数组进行迭代:

In [607]: pixels = np.argwhere(arr>.999)
In [608]: timeit [pixel for pixel in pixels]
347 µs ± 5.29 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

使用 zip(*) 转置从 where 迭代元组:

In [609]: idx = np.where(arr>.999)
In [610]: timeit [pixel for pixel in zip(*idx)]
256 µs ± 147 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)

在数组上迭代通常比在列表上迭代要慢一点,或者在这种情况下是压缩数组。

In [611]: [pixel for pixel in pixels][:5]
Out[611]:
[array([ 1, 273], dtype=int32),
array([ 1, 471], dtype=int32),
array([ 1, 584], dtype=int32),
array([ 1, 826], dtype=int32),
array([ 2, 169], dtype=int32)]
In [612]: [pixel for pixel in zip(*idx)][:5]
Out[612]: [(1, 273), (1, 471), (1, 584), (1, 826), (2, 169)]

一个是数组列表,另一个是元组列表。但是将这些元组(单独)转换为数组很慢:

In [614]: timeit [np.array(pixel) for pixel in zip(*idx)]
2.26 ms ± 4.94 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

在平面非零数组上迭代更快

In [617]: fdx = np.flatnonzero(arr>.999)
In [618]: fdx[:5]
Out[618]: array([1273, 1471, 1584, 1826, 2169], dtype=int32)
In [619]: timeit [i for i in fdx]
112 µs ± 23.5 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

但是对这些值单独应用unravel 需要时间。

def foo(idx):    # a simplified unravel
return idx//1000, idx%1000

In [628]: timeit [foo(i) for i in fdx]
1.12 ms ± 1.02 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

将这 1 毫秒添加到 3 毫秒以生成 fdx,此 flatnonzero 可能仍会领先。但在最好的情况下,我们谈论的是 2 倍的速度提升。

关于python - 替代 numpy.argwhere 以加速 python 中的循环,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/47068017/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com