gpt4 book ai didi

python - python中FFT的循环加速(使用 `np.einsum`)

转载 作者:行者123 更新时间:2023-12-03 23:08:10 26 4
gpt4 key购买 nike

问题:我想用 np.einsum 加速包含大量乘积和求和的 python 循环。 ,但我也愿意接受任何其他解决方案。

我的函数采用形状为 (n,n,3) 的向量配置 S(我的情况:n=72)并对 N*N 点的相关函数进行傅立叶变换。相关函数定义为每个向量与其他向量的乘积。这乘以向量位置乘以 kx 和 ky 值的余弦函数。各岗位i,j最后相加得到 k 空间中的一分 p,m :

def spin_spin(S,N):
n= len(S)
conf = np.reshape(S,(n**2,3))
chi = np.zeros((N,N))
kx = np.linspace(-5*np.pi/3,5*np.pi/3,N)
ky = np.linspace(-3*np.pi/np.sqrt(3),3*np.pi/np.sqrt(3),N)

x=np.reshape(triangular(n)[0],(n**2))
y=np.reshape(triangular(n)[1],(n**2))
for p in range(N):
for m in range(N):
for i in range(n**2):
for j in range(n**2):
chi[p,m] += 2/(n**2)*np.dot(conf[i],conf[j])*np.cos(kx[p]*(x[i]-x[j])+ ky[m]*(y[i]-y[j]))
return(chi,kx,ky)

我的问题是我需要大约 100*100 个点,由 kx*ky 表示,并且循环需要很多小时才能完成具有 72*72 向量的格子的这项工作。
计算数量:72*72*72*72*100*100
我无法使用 numpy 的内置 FFT ,因为我的三角形网格,所以我需要一些其他的选择来减少这里的计算成本。

我的想法:首先,我意识到将配置 reshape 为向量列表而不是矩阵可以降低计算成本。另外我用了numba包,也降低了成本,但是还是太慢了。我发现计算这类对象的一个​​好方法是 np.einsum功能。计算每个向量与每个向量的乘积是通过以下方式完成的:
np.einsum('ij,kj -> ik',np.reshape(S,(72**2,3)),np.reshape(S,(72**2,3)))

棘手的部分是计算 np.cos 中的项。 .在这里,我想计算形状列表 (100,1) 与向量位置之间的乘积(例如 np.shape(x)=(72**2,1) )。特别是我真的不知道如何用 np.einsum 实现x 方向和y 方向的距离。 .

重现代码(可能你不需要这个):首先你需要一个矢量配置。您可以简单地使用 np.ones((72,72,3)或者您以随机向量为例:
def spherical_to_cartesian(r, theta, phi):
'''Convert spherical coordinates (physics convention) to cartesian coordinates'''
sin_theta = np.sin(theta)
x = r * sin_theta * np.cos(phi)
y = r * sin_theta * np.sin(phi)
z = r * np.cos(theta)

return x, y, z # return a tuple

def random_directions(n, r):
'''Return ``n`` 3-vectors in random directions with radius ``r``'''
out = np.empty(shape=(n,3), dtype=np.float64)

for i in range(n):
# Pick directions randomly in solid angle
phi = random.uniform(0, 2*np.pi)
theta = np.arccos(random.uniform(-1, 1))
# unpack a tuple
x, y, z = spherical_to_cartesian(r, theta, phi)
out[i] = x, y, z

return out
S = np.reshape(random_directions(72**2,1),(72,72,3))

(本例中的 reshape 需要在函数 spin_spin 中将其整形回 (72**2,3) 形状。)

对于向量的位置,我使用由以下定义的三角形网格
def triangular(nsize):
'''Positional arguments of the spin configuration'''

X=np.zeros((nsize,nsize))
Y=np.zeros((nsize,nsize))
for i in range(nsize):
for j in range(nsize):
X[i,j]+=1/2*j+i
Y[i,j]+=np.sqrt(3)/2*j
return(X,Y)

最佳答案

优化 Numba 实现

您代码中的主要问题是调用外部 BLAS 函数 np.dot反复极数据。在这段代码中,只计算一次会更有意义,但是如果您必须在循环中进行计算,请编写一个 Numba 实现。 Example

优化功能(蛮力)

import numpy as np
import numba as nb

@nb.njit(fastmath=True,error_model="numpy",parallel=True)
def spin_spin(S,N):
n= len(S)
conf = np.reshape(S,(n**2,3))
chi = np.zeros((N,N))
kx = np.linspace(-5*np.pi/3,5*np.pi/3,N).astype(np.float32)
ky = np.linspace(-3*np.pi/np.sqrt(3),3*np.pi/np.sqrt(3),N).astype(np.float32)

x=np.reshape(triangular(n)[0],(n**2)).astype(np.float32)
y=np.reshape(triangular(n)[1],(n**2)).astype(np.float32)

#precalc some values
fact=nb.float32(2/(n**2))
conf_dot=np.dot(conf,conf.T).astype(np.float32)

for p in nb.prange(N):
for m in range(N):
#accumulating on a scalar is often beneficial
acc=nb.float32(0)
for i in range(n**2):
for j in range(n**2):
acc+= conf_dot[i,j]*np.cos(kx[p]*(x[i]-x[j])+ ky[m]*(y[i]-y[j]))
chi[p,m]=fact*acc

return(chi,kx,ky)

优化功能(去除冗余计算)

做了很多多余的计算。这是有关如何删除它们的示例。这也是一个以 double 进行计算的版本。
@nb.njit()
def precalc(S):
#There may not be all redundancies removed
n= len(S)
conf = np.reshape(S,(n**2,3))
conf_dot=np.dot(conf,conf.T)
x=np.reshape(triangular(n)[0],(n**2))
y=np.reshape(triangular(n)[1],(n**2))

x_s=set()
y_s=set()
for i in range(n**2):
for j in range(n**2):
x_s.add((x[i]-x[j]))
y_s.add((y[i]-y[j]))

x_arr=np.sort(np.array(list(x_s)))
y_arr=np.sort(np.array(list(y_s)))


conf_dot_sel=np.zeros((x_arr.shape[0],y_arr.shape[0]))
for i in range(n**2):
for j in range(n**2):
ii=np.searchsorted(x_arr,x[i]-x[j])
jj=np.searchsorted(y_arr,y[i]-y[j])
conf_dot_sel[ii,jj]+=conf_dot[i,j]

return x_arr,y_arr,conf_dot_sel

@nb.njit(fastmath=True,error_model="numpy",parallel=True)
def spin_spin_opt_2(S,N):
chi = np.empty((N,N))
n= len(S)

kx = np.linspace(-5*np.pi/3,5*np.pi/3,N)
ky = np.linspace(-3*np.pi/np.sqrt(3),3*np.pi/np.sqrt(3),N)

x_arr,y_arr,conf_dot_sel=precalc(S)
fact=2/(n**2)
for p in nb.prange(N):
for m in range(N):
acc=nb.float32(0)
for i in range(x_arr.shape[0]):
for j in range(y_arr.shape[0]):
acc+= fact*conf_dot_sel[i,j]*np.cos(kx[p]*x_arr[i]+ ky[m]*y_arr[j])
chi[p,m]=acc

return(chi,kx,ky)

@nb.njit()
def precalc(S):
#There may not be all redundancies removed
n= len(S)
conf = np.reshape(S,(n**2,3))
conf_dot=np.dot(conf,conf.T)
x=np.reshape(triangular(n)[0],(n**2))
y=np.reshape(triangular(n)[1],(n**2))

x_s=set()
y_s=set()
for i in range(n**2):
for j in range(n**2):
x_s.add((x[i]-x[j]))
y_s.add((y[i]-y[j]))

x_arr=np.sort(np.array(list(x_s)))
y_arr=np.sort(np.array(list(y_s)))


conf_dot_sel=np.zeros((x_arr.shape[0],y_arr.shape[0]))
for i in range(n**2):
for j in range(n**2):
ii=np.searchsorted(x_arr,x[i]-x[j])
jj=np.searchsorted(y_arr,y[i]-y[j])
conf_dot_sel[ii,jj]+=conf_dot[i,j]

return x_arr,y_arr,conf_dot_sel

@nb.njit(fastmath=True,error_model="numpy",parallel=True)
def spin_spin_opt_2(S,N):
chi = np.empty((N,N))
n= len(S)

kx = np.linspace(-5*np.pi/3,5*np.pi/3,N)
ky = np.linspace(-3*np.pi/np.sqrt(3),3*np.pi/np.sqrt(3),N)

x_arr,y_arr,conf_dot_sel=precalc(S)
fact=2/(n**2)
for p in nb.prange(N):
for m in range(N):
acc=nb.float32(0)
for i in range(x_arr.shape[0]):
for j in range(y_arr.shape[0]):
acc+= fact*conf_dot_sel[i,j]*np.cos(kx[p]*x_arr[i]+ ky[m]*y_arr[j])
chi[p,m]=acc

return(chi,kx,ky)

计时
#brute-force
%timeit res=spin_spin(S,100)
#48 s ± 671 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

#new version
%timeit res_2=spin_spin_opt_2(S,100)
#5.33 s ± 59.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit res_2=spin_spin_opt_2(S,1000)
#1min 23s ± 2.43 s per loop (mean ± std. dev. of 7 runs, 1 loop each)

编辑(SVML 检查)
import numba as nb
import numpy as np

@nb.njit(fastmath=True)
def foo(n):
x = np.empty(n*8, dtype=np.float64)
ret = np.empty_like(x)
for i in range(ret.size):
ret[i] += np.cos(x[i])
return ret

foo(1000)

if 'intel_svmlcc' in foo.inspect_llvm(foo.signatures[0]):
print("found")
else:
print("not found")

#found

如果有 not found阅读 this link.它应该可以在 Linux 和 Windows 上运行,但我还没有在 macOS 上测试过。

关于python - python中FFT的循环加速(使用 `np.einsum`),我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/60934744/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com