gpt4 book ai didi

python - 使用 Numba 使四个嵌套 for 循环更快

转载 作者:行者123 更新时间:2023-11-28 22:18:47 25 4
gpt4 key购买 nike

我对使用 Numba 有点陌生,但我明白了它的要点。我想知道是否有任何更高级的技巧可以使四个嵌套的 for 循环比我现在拥有的更快。特别是,我需要计算以下积分:

enter image description here

其中B是一个二维数组,S0和E是某些参数。我的代码如下:

import numpy as np
from numba import njit, double

def calc_gb_gauss_2d(b,s0,e,dx):
n,m=b.shape
norm = 1.0/(2*np.pi*s0**2)
gb = np.zeros((n,m))
for i in range(n):
for j in range(m):
for ii in range(n):
for jj in range(m):
gb[i,j]+=np.exp(-(((i-ii)*dx)**2+((j-jj)*dx)**2)/(2.0*(s0*(1.0+e*b[i,j]))**2))
gb[i,j]*=norm
return gb

calc_gb_gauss_2d_nb = njit(double[:, :](double[:, :],double,double,double))(calc_gb_gauss_2d)

对于大小为 256x256 的输入数组,计算速度为:

In [4]: a=random.random((256,256))

In [5]: %timeit calc_gb_gauss_2d_nb(a,0.1,1.0,0.5)
The slowest run took 8.46 times longer than the fastest. This could mean that an intermediate result is being cached.
1 loop, best of 3: 1min 1s per loop

纯Python和Numba计算速度对比给我这张图: enter image description here

有什么方法可以优化我的代码以获得更好的性能吗?

最佳答案

通过使用 numpy 和一些数学可以加速你的代码,因此它比当前的 numba 版本快一个数量级。我们还将看到,在改进的函数上使用 numba 会使它更快。

numba 经常被过度使用 - 通常可以编写非常高效的 numpy-only 代码 - 这里也是这种情况。

手头的 numpy 代码存在一个问题:不应访问单个元素,而应利用 numpy 的内置函数 - 它们在大多数情况下都尽可能快。只有在无法使用这些 numpy 函数时,才会使用 numba 或 cython。

然而,这里最大的问题是问题的表述。对于固定的ij,我们有下面的公式来计算(我稍微简化了一下):

 g[i,j]=sum_ii sum_jj exp(value_ii+value_jj)
=sum_ii sum_jj exp(value_ii)*exp(value_jj)
=sum_ii exp(value_ii) * sum_jj exp(value_jj)

要计算最后一个公式,我们需要 O(n+m) 操作,但对于第一个简单的公式 O(n*m) - 差别很大!

利用 numpy 功能的第一个版本可能类似于:

def calc_ead(b,s0,e,dx):
n,m=b.shape
norm = 1.0/(2*np.pi*s0**2)
gb = np.zeros((n,m))
vI=np.arange(n)
vJ=np.arange(m)
for i in range(n):
for j in range(m):
II=(i-vI)*dx
JJ=(j-vJ)*dx
denom=2.0*(s0*(1.0+e*b[i,j]))**2
expII=np.exp(-II*II/denom)
expJJ=np.exp(-JJ*JJ/denom)
gb[i,j]=norm*(expII.sum()*expJJ.sum())
return gb

现在,与最初的 numba 实现相比:

>>> a=np.random.random((256,256))

>>> print(calc_gb_gauss_2d_nb(a,0.1,1.0,0.5)[1,1])
15.9160709993
>>> %timeit -n1 -r1 calc_gb_gauss_2d_nb(a,0.1,1.0,0.5)
1min 6s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)

现在是 numpy 函数:

>>> print(calc_ead(a,0.1,1.0,0.5)[1,1])
15.9160709993
>>> %timeit -n1 -r1 calc_ead(a,0.1,1.0,0.5)
1.8 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)

有两个观察结果:

  1. 结果是一样的。
  2. numpy 版本快 37 倍,对于更大的问题,这种差异会变得更大。

显然,您可以利用 numba 来实现更大的加速。然而,在可能的情况下使用 numpy 功能仍然是一个好主意 - 令人惊讶的是,最简单的事情可以如此微妙 - 例如甚至 calculating a sum :

>>> nb_calc_ead = njit(double[:, :](double[:, :],double,double,double))(calc_ead)
>>>print(nb_calc_ead(a,0.1,1.0,0.5)[1,1])
15.9160709993
>>>%timeit -n1 -r1 nb_calc_ead(a,0.1,1.0,0.5)
587 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)

还有一个因素3!

这个问题可以并行化,但要做到这一点并非易事。我便宜的尝试使用 explicit loop parallelization :

from numba import njit, prange
import math

@njit(parallel=True) #needed, so it is parallelized
def parallel_nb_calc_ead(b,s0,e,dx):
n,m=b.shape
norm = 1.0/(2*np.pi*s0**2)
gb = np.zeros((n,m))
vI=np.arange(n)
vJ=np.arange(m)
for i in prange(n): #outer loop = explicit prange-loop
for j in range(m):
denom=2.0*(s0*(1.0+e*b[i,j]))**2
expII=np.zeros((n,))
expJJ=np.zeros((m,))
for k in range(n):
II=(i-vI[k])*dx
expII[k]=math.exp(-II*II/denom)

for k in range(m):
JJ=(j-vJ[k])*dx
expJJ[k]=math.exp(-JJ*JJ/denom)
gb[i,j]=norm*(expII.sum()*expJJ.sum())
return gb

现在:

>>> print(parallel_nb_calc_ead(a,0.1,1.0,0.5)[1,1])
15.9160709993
>>> %timeit -n1 -r1 parallel_nb_calc_ead(a,0.1,1.0,0.5)
349 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)

几乎意味着另一个因素 2(我的机器只有两个 CPU,取决于硬件,加速可能更大)。顺便说一下,我们的速度比原始版本快了将近 200 倍。

我打赌可以改进上面的代码,但我不会去那里。


列出与 calc_ead 比较的当前版本:

import numpy as np
from numba import njit, double

def calc_gb_gauss_2d(b,s0,e,dx):
n,m=b.shape
norm = 1.0/(2*np.pi*s0**2)
gb = np.zeros((n,m))
for i in range(n):
for j in range(m):
for ii in range(n):
for jj in range(m):
gb[i,j]+=np.exp(-(((i-ii)*dx)**2+((j-jj)*dx)**2)/(2.0*(s0*(1.0+e*b[i,j]))**2))
gb[i,j]*=norm
return gb

calc_gb_gauss_2d_nb = njit(double[:, :](double[:, :],double,double,double))(calc_gb_gauss_2d)

关于python - 使用 Numba 使四个嵌套 for 循环更快,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/50252083/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com