gpt4 book ai didi

python - 优化包含多个 np.multiply 语句的代码片段

转载 作者:行者123 更新时间:2023-12-02 10:56:11 29 4
gpt4 key购买 nike

我致力于优化 python 模块中的一些代码。我已经确定了瓶颈,并且是一个在 numpy 中进行一些计算的代码片段。即如下代码:

    # Apply the 3x3 matrix K_Rinv to the three arrays (x, y, h) element-wise:
    #   output r = K_Rinv[r, 0]*x + K_Rinv[r, 1]*y + K_Rinv[r, 2]*h
    # Work one column of K_Rinv at a time: each output starts as a fresh
    # product with x, then the y and h terms are accumulated in place.
    xh = np.multiply(K_Rinv[0, 0], x)
    yh = np.multiply(K_Rinv[1, 0], x)
    q = np.multiply(K_Rinv[2, 0], x)
    xh += np.multiply(K_Rinv[0, 1], y)
    yh += np.multiply(K_Rinv[1, 1], y)
    q += np.multiply(K_Rinv[2, 1], y)
    xh += np.multiply(K_Rinv[0, 2], h)
    yh += np.multiply(K_Rinv[1, 2], h)
    q += np.multiply(K_Rinv[2, 2], h)

其中 x、y 和 h 是形状为 (4206, 5749) 的 np.ndarray,K_Rinv 是形状为 (3, 3) 的 np.ndarray。该代码片段被多次调用,占用了整个程序 50% 以上的运行时间。有没有办法加快速度?还是说它已经无法再加速了?

编辑1:
感谢您的回答。在使用 numba 遇到问题后(请参阅最后的错误消息),我尝试了 numexpr 的建议。然而,使用这个方案时我的代码崩溃了。于是我检查了两种方法的结果是否相同,发现并不相同。这是我正在使用的代码:

    # numexpr version: one expression per output row of K_Rinv.
    xh_1 = numexpr.evaluate('a1*b1+a2*b2+a3*b3', {'a1': K_Rinv[0, 0], 'b1': x,
                                                  'a2': K_Rinv[0, 1], 'b2': y,
                                                  'a3': K_Rinv[0, 2], 'b3': h})
    yh_1 = numexpr.evaluate('a1*b1+a2*b2+a3*b3', {'a1': K_Rinv[1, 0], 'b1': x,
                                                  'a2': K_Rinv[1, 1], 'b2': y,
                                                  'a3': K_Rinv[1, 2], 'b3': h})
    q_1 = numexpr.evaluate('a1*b1+a2*b2+a3*b3', {'a1': K_Rinv[2, 0], 'b1': x,
                                                 'a2': K_Rinv[2, 1], 'b2': y,
                                                 'a3': K_Rinv[2, 2], 'b3': h})
    # Reference version with plain NumPy, accumulating in place.
    xh_2 = np.multiply(K_Rinv[0, 0], x)
    xh_2 += np.multiply(K_Rinv[0, 1], y)
    xh_2 += np.multiply(K_Rinv[0, 2], h)
    yh_2 = np.multiply(K_Rinv[1, 0], x)
    yh_2 += np.multiply(K_Rinv[1, 1], y)
    yh_2 += np.multiply(K_Rinv[1, 2], h)
    q_2 = np.multiply(K_Rinv[2, 0], x)
    q_2 += np.multiply(K_Rinv[2, 1], y)
    q_2 += np.multiply(K_Rinv[2, 2], h)
    # Compare element-wise. `a.all() == b.all()` would only compare two
    # booleans ("is every element non-zero?") and says nothing about whether
    # the arrays match. allclose also tolerates last-bit rounding differences,
    # should the two evaluation paths round differently.
    check1 = np.allclose(xh_1, xh_2)
    check2 = np.allclose(yh_1, yh_2)
    check3 = np.allclose(q_1, q_2)
    print(" Check1 :{} , Check2: {} , Check3:{}".format(check1, check2, check3))

我没有使用 numexpr 的经验——两种方法的结果通常会不一样吗?

numba 错误:

 File "/usr/local/lib/python3.6/dist-packages/numba/dispatcher.py", line 420, in _compile_for_args
raise e
File "/usr/local/lib/python3.6/dist-packages/numba/dispatcher.py", line 353, in _compile_for_args
return self.compile(tuple(argtypes))
File "/usr/local/lib/python3.6/dist-packages/numba/compiler_lock.py", line 32, in _acquire_compile_lock
return func(*args, **kwargs)
File "/usr/local/lib/python3.6/dist-packages/numba/dispatcher.py", line 768, in compile
cres = self._compiler.compile(args, return_type)
File "/usr/local/lib/python3.6/dist-packages/numba/dispatcher.py", line 77, in compile
status, retval = self._compile_cached(args, return_type)
File "/usr/local/lib/python3.6/dist-packages/numba/dispatcher.py", line 91, in _compile_cached
retval = self._compile_core(args, return_type)
File "/usr/local/lib/python3.6/dist-packages/numba/dispatcher.py", line 109, in _compile_core
pipeline_class=self.pipeline_class)
File "/usr/local/lib/python3.6/dist-packages/numba/compiler.py", line 551, in compile_extra
return pipeline.compile_extra(func)
File "/usr/local/lib/python3.6/dist-packages/numba/compiler.py", line 327, in compile_extra
raise e
File "/usr/local/lib/python3.6/dist-packages/numba/compiler.py", line 321, in compile_extra
ExtractByteCode().run_pass(self.state)
File "/usr/local/lib/python3.6/dist-packages/numba/untyped_passes.py", line 67, in run_pass
bc = bytecode.ByteCode(func_id)
File "/usr/local/lib/python3.6/dist-packages/numba/bytecode.py", line 215, in __init__
self._compute_lineno(table, code)
File "/usr/local/lib/python3.6/dist-packages/numba/bytecode.py", line 237, in _compute_lineno
known = table[_FIXED_OFFSET].lineno
KeyError: 2

编辑2:感谢大家的评论和回答。再次检查代码后,numexpr 方案是可行的,非常感谢。我还在一个单独的 python 文件中测试了 numba 代码能否运行——它确实能运行,但速度非常慢。以下是我使用的代码:

import numpy as np
import numba as nb
import numexpr
from datetime import datetime

def calc(x, y, h, K_Rinv):
    """Apply the 3x3 matrix K_Rinv to the arrays (x, y, h) with plain NumPy.

    For r = 0, 1, 2 the r-th output is
    K_Rinv[r, 0]*x + K_Rinv[r, 1]*y + K_Rinv[r, 2]*h, evaluated element-wise.

    Returns:
        Tuple (xh, yh, q) of newly allocated arrays; the inputs are not
        modified.
    """
    def _combine(row):
        # Start from a fresh product, then add the other two terms in place
        # so no extra array is allocated for each running sum.
        total = np.multiply(K_Rinv[row, 0], x)
        total += np.multiply(K_Rinv[row, 1], y)
        total += np.multiply(K_Rinv[row, 2], h)
        return total

    return _combine(0), _combine(1), _combine(2)

def calc_numexpr(x, y, h, K_Rinv):
    """Apply the 3x3 matrix K_Rinv to the arrays (x, y, h) using numexpr.

    For r = 0, 1, 2 the r-th output is
    K_Rinv[r, 0]*x + K_Rinv[r, 1]*y + K_Rinv[r, 2]*h. Each output is
    produced by one numexpr.evaluate call on the same expression string,
    with that row's coefficients bound as scalar operands.

    Returns:
        Tuple (xh, yh, q) of the three evaluated arrays.
    """
    expression = 'a1*b1+a2*b2+a3*b3'
    outputs = []
    for row in range(3):
        operands = {
            'a1': K_Rinv[row, 0], 'b1': x,
            'a2': K_Rinv[row, 1], 'b2': y,
            'a3': K_Rinv[row, 2], 'b3': h,
        }
        outputs.append(numexpr.evaluate(expression, operands))
    xh, yh, q = outputs
    return xh, yh, q


@nb.njit(fastmath=True, parallel=True)
def calc_nb(x, y, h, K_Rinv):
    """Apply the 3x3 matrix K_Rinv to (x, y, h) element by element, via Numba.

    For every index (i, j):
        xh[i, j] = K_Rinv[0, 0]*x[i, j] + K_Rinv[0, 1]*y[i, j] + K_Rinv[0, 2]*h[i, j]
        yh[i, j] = the same with row 1 of K_Rinv
        q[i, j]  = the same with row 2 of K_Rinv

    Notes:
        fastmath=True lets the compiler reassociate/contract the
        floating-point operations, so results may differ from the NumPy
        version in the last bits -- compare with a tolerance, not ==.
        The first call for a given set of argument types also pays the
        JIT compilation cost.

    Returns:
        Tuple (xh, yh, q) of new arrays shaped and typed like x.
    """
    n_rows, n_cols = x.shape
    xh = np.empty_like(x)
    yh = np.empty_like(x)
    q = np.empty_like(x)
    # With parallel=True, prange spreads the outer (row) loop over threads.
    for i in nb.prange(n_rows):
        for j in range(n_cols):
            xv = x[i, j]
            yv = y[i, j]
            hv = h[i, j]
            xh[i, j] = K_Rinv[0, 0] * xv + K_Rinv[0, 1] * yv + K_Rinv[0, 2] * hv
            yh[i, j] = K_Rinv[1, 0] * xv + K_Rinv[1, 1] * yv + K_Rinv[1, 2] * hv
            q[i, j] = K_Rinv[2, 0] * xv + K_Rinv[2, 1] * yv + K_Rinv[2, 2] * hv
    return xh, yh, q


x = np.random.random((4206, 5749))
y = np.random.random((4206, 5749))
h = np.random.random((4206, 5749))
K_Rinv = np.random.random((3, 3))

# Warm up calc_nb before timing it. The first call of an @nb.njit function
# compiles it for the given argument types, so timing that call measures the
# compiler rather than the loop. A tiny C-contiguous float64 2-D array has the
# same Numba type as the full-size inputs, so this compiles exactly the
# specialisation that is timed below.
_warmup = np.random.random((2, 2))
calc_nb(_warmup, _warmup, _warmup, K_Rinv)

start = datetime.now()
x_calc, y_calc, q_calc = calc(x, y, h, K_Rinv)
end = datetime.now()
print("Calc took: {} ".format(end - start))

start = datetime.now()
x_numexpr, y_numexpr, q_numexpr = calc_numexpr(x, y, h, K_Rinv)
end = datetime.now()
print("Calc_numexpr took: {} ".format(end - start))

start = datetime.now()
x_nb, y_nb, q_nb = calc_nb(x, y, h, K_Rinv)
end = datetime.now()
print("Calc nb took: {} ".format(end - start))

# Compare with a tolerance, not ==. calc_nb is compiled with fastmath=True,
# which lets the compiler reassociate and contract (e.g. into fused
# multiply-adds) the floating-point operations, so its results can differ
# from NumPy's in the last bits even though the formula is identical.
check_nb_q = np.allclose(q_calc, q_nb)
check_nb_y = np.allclose(y_calc, y_nb)
check_nb_x = np.allclose(x_calc, x_nb)

check_numexpr_q = np.allclose(q_calc, q_numexpr)
check_numexpr_y = np.allclose(y_calc, y_numexpr)
check_numexpr_x = np.allclose(x_calc, x_numexpr)

print("Checks for numexpr: {} , {} ,{} \nChecks for nb: {} ,{}, {}" .format(check_numexpr_x,check_numexpr_y,check_numexpr_q,check_nb_x,check_nb_y,check_nb_q))

输出以下内容:

Calc took:           0:00:01.944150 
Calc_numexpr took: 0:00:00.616224
Calc nb took: 0:00:01.553058
Checks for numexpr: True , True ,True
Checks for nb: False ,False, False

所以 numba 版本无法按预期工作。知道我做错了什么吗?希望 numba 解决方案也能正常工作。

附注:numba 的版本是 0.47.0。

最佳答案

另一种可能性是使用 Numba。

示例

import numpy as np
import numba as nb

@nb.njit(fastmath=True, parallel=True)
def calc_nb(x, y, h, K_Rinv):
    """Apply the 3x3 matrix K_Rinv to (x, y, h) element by element, via Numba.

    For every index (i, j), output r (r = 0, 1, 2 -> xh, yh, q) is
    K_Rinv[r, 0]*x[i, j] + K_Rinv[r, 1]*y[i, j] + K_Rinv[r, 2]*h[i, j].
    fastmath=True permits floating-point reassociation, so results may
    differ from a NumPy reference in the last bits.

    Returns:
        Tuple (xh, yh, q) of new arrays shaped and typed like x.
    """
    # Read the nine coefficients once, outside the hot loop.
    k00, k01, k02 = K_Rinv[0, 0], K_Rinv[0, 1], K_Rinv[0, 2]
    k10, k11, k12 = K_Rinv[1, 0], K_Rinv[1, 1], K_Rinv[1, 2]
    k20, k21, k22 = K_Rinv[2, 0], K_Rinv[2, 1], K_Rinv[2, 2]

    xh = np.empty_like(x)
    yh = np.empty_like(x)
    q = np.empty_like(x)
    # parallel=True + prange: rows are split across threads.
    for row in nb.prange(x.shape[0]):
        for col in range(x.shape[1]):
            a = x[row, col]
            b = y[row, col]
            c = h[row, col]
            xh[row, col] = k00 * a + k01 * b + k02 * c
            yh[row, col] = k10 * a + k11 * b + k12 * c
            q[row, col] = k20 * a + k21 * b + k22 * c
    return xh, yh, q

这个计算是否受内存带宽限制?

def copy(x, y, h, K_Rinv):
    """Return independent copies of x, y and h; K_Rinv is accepted but unused.

    Used as a memory-traffic baseline: like the calc_* functions it reads
    three input arrays and produces three new ones, but does no arithmetic.
    """
    return tuple(np.copy(arr) for arr in (x, y, h))

# IPython cell (%timeit is an IPython magic, not plain Python).
# Baseline: time a plain copy of the three inputs with no arithmetic. If this
# is close to the calc_* timings, the computation is memory-bandwidth bound.
%timeit copy(x,y,h,K_Rinv)
#147 ms ± 4.98 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

此计算受内存带宽和动态内存分配的限制,其中的乘法运算对性能几乎没有影响。

时间

# IPython session: %timeit is an IPython magic, not plain Python.
# Inputs: three 4206x5749 arrays and a 3x3 matrix, filled by np.random.random.
x = np.random.random((4206, 5749))
y = np.random.random((4206, 5749))
h = np.random.random((4206, 5749))
K_Rinv = np.random.random((3, 3))

# The "#..." line after each %timeit is the recorded output of that run.
%timeit calc(x,y,h,K_Rinv) #Your implementation
#581 ms ± 8.05 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit calc_nb(x,y,h,K_Rinv)
#145 ms ± 3.81 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# NOTE(review): calc_numexpr_scleronomic and calc_Daniel_F are not defined in
# this text -- presumably they come from other answers on the original page.
%timeit calc_numexpr_scleronomic(x,y,h,K_Rinv)
#175 ms ± 1.83 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit calc_Daniel_F(x,y,h,K_Rinv)
#589 ms ± 24.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

可能的进一步优化:重用已分配的内存

@nb.njit(fastmath=True, parallel=True)
def calc_nb_2(x, y, h, K_Rinv, xh, yh, q):
    """Apply K_Rinv to (x, y, h) element-wise, writing into caller buffers.

    For every index (i, j), output r (r = 0, 1, 2 -> xh, yh, q) is
    K_Rinv[r, 0]*x[i, j] + K_Rinv[r, 1]*y[i, j] + K_Rinv[r, 2]*h[i, j].
    xh, yh and q are overwritten in place and also returned, so a caller can
    allocate them once and reuse them across calls. They should have x's
    shape, since the loop indexes them with x's indices.
    """
    # Hoist the coefficients out of the parallel loop.
    k00, k01, k02 = K_Rinv[0, 0], K_Rinv[0, 1], K_Rinv[0, 2]
    k10, k11, k12 = K_Rinv[1, 0], K_Rinv[1, 1], K_Rinv[1, 2]
    k20, k21, k22 = K_Rinv[2, 0], K_Rinv[2, 1], K_Rinv[2, 2]

    n_rows = x.shape[0]
    n_cols = x.shape[1]
    for r in nb.prange(n_rows):
        for c in range(n_cols):
            xv = x[r, c]
            yv = y[r, c]
            hv = h[r, c]
            xh[r, c] = k00 * xv + k01 * yv + k02 * hv
            yh[r, c] = k10 * xv + k11 * yv + k12 * hv
            q[r, c] = k20 * xv + k21 * yv + k22 * hv
    return xh, yh, q

#allocate memory only once if you call this function repeatedly
xh=np.empty_like(x)
yh=np.empty_like(x)
q=np.empty_like(x)

%timeit calc_nb_2(x,y,h,K_Rinv,xh,yh,q)
69.2 ms ± 194 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

关于python - 优化包含多个 np.multiply 语句的代码片段,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/59677789/

29 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com