gpt4 book ai didi

python - Blitz 代码产生不同的输出

转载 作者:行者123 更新时间:2023-12-01 05:45:41 25 4
gpt4 key购买 nike

我想使用 weave.blitz 来提高以下 numpy 代码的性能:

def fastIteration(self):
g = self.grid
nx,ny = g.ux.shape

uxold = g.old_ux
ux = g.ux
ux[0:,1:-1] = uxold[0:,1:-1] + ReI* (uxold[0:,2:] - 2*uxold[0:,1:-1] + uxold[0:,0:-2])

g.setBC()
g.old_ux = ux.copy()

在此代码中,g 是计算网格。其中由两个不同的字段ux和uxold组成。 old 只是用来临时存储变量。在完整的代码中,大约 95% 的运行时间花费在 fastIteration 方法中,因此即使是简单的性能提升也会显着减少执行此代码所花费的时间。

numpy 方法的输出如下所示:

numpy result

由于这段代码是我的瓶颈,我想通过使用 weave blitz 来提高速度。这个方法看起来像:

def blitzIteration(self):
### does not work correct so far
g = self.grid
nx,ny = g.ux.shape

uxold = g.old_ux
ux = g.ux
expr = "ux[0:,1:-1] = uxold[0:,1:-1] + ReI* (uxold[0:,2:] - 2*uxold[0:,1:-1] + uxold[0:,0:-2])"
weave.blitz(expr, check_size=0)
g.setBC()
g.old_ux = ux.copy()

但是这不会产生正确的输出: output for blitz code

最佳答案

它看起来像 weave.blitz 中的错误(已复制、归档和 fixed 。那里有关于实际错误的更多信息)。

我觉得写 0: 而不是更短的 : 来获得完整的切片很奇怪,所以我替换了所有这些切片,瞧,它起作用了。

我真的不知道错误在哪里,但是expr_codeweave.blitz 生成的略有不同:

  • 使用0时:

    ipdb> expr_code
    'ux_blitz_buggy(blitz::Range(0,_end),blitz::Range(1,Nux_blitz_buggy(1)-1-1))=uxold(blitz::Range(0,_end),blitz::Range(1,Nuxold(1)-1-1))+ReI*(uxold(blitz::Range(0,_end),blitz::Range(2,_end))-2*uxold(blitz::Range(0,_end),blitz::Range(1,Nuxold(1)-1-1))+uxold(blitz::Range(0,_end),blitz::Range(0,Nuxold(1)-2-1)));\n'
  • 使用时:

    ipdb> expr_code
    'ux_blitz_not_buggy(_all,blitz::Range(1,Nux_blitz_not_buggy(1)-1-1))=uxold(_all,blitz::Range(1,Nuxold(1)-1-1))+ReI*(uxold(_all,blitz::Range(2,_end))-2*uxold(_all,blitz::Range(1,Nuxold(1)-1-1))+uxold(_all,blitz::Range(0,Nuxold(1)-2-1)));\n'

因此,blitz::Range(0,_end) 变为 _all 并且它们的行为方式不同。

为了方便起见,这里有一个完整的脚本,可以重现问题,并且只有在问题存在时才会成功。

import numpy as np
from scipy.weave import blitz


def test_blitz_bug(N=4):
ReI = 1.2
ux_blitz_buggy, ux_blitz_not_buggy, ux_np = np.zeros((N, N)), np.zeros((N, N)), np.zeros((N, N))
uxold = np.random.randn(N, N)
ux_np[0:,1:-1] = uxold[0:,1:-1] + ReI* (uxold[0:,2:] - 2*uxold[0:,1:-1] + uxold[0:,0:-2])
expr_buggy = 'ux_blitz_buggy[0:,1:-1] = uxold[0:,1:-1] + ReI* (uxold[0:,2:] - 2*uxold[0:,1:-1] + uxold[0:,0:-2])'
expr_not_buggy = 'ux_blitz_not_buggy[:,1:-1] = uxold[:,1:-1] + ReI* (uxold[:,2:] - 2*uxold[:,1:-1] + uxold[:,0:-2])'
blitz(expr_buggy)
blitz(expr_not_buggy)
assert not np.allclose(ux_blitz_buggy, ux_np)
assert np.allclose(ux_blitz_not_buggy, ux_np)

if __name__ == '__main__':
test_blitz_bug()

关于python - Blitz 代码产生不同的输出,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/16230848/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com