gpt4 book ai didi

python - Cython 不够快

转载 作者:太空宇宙 更新时间:2023-11-04 08:04:23 25 4
gpt4 key购买 nike

我在 cython 中重写了我的 python 循环,期望速度有很大的提高。我只得到大约四分之一。难道我做错了什么?这是没有 cython 的代码:

import numpy as np
import itertools as itr
import math

def Pk (b, f, mu, k): # k is in Mpc
isoPk = 200*math.exp(-(k-0.02)**2/2/0.01**2) # Isotropic power spectrum
power = (b+mu**2*f)**2*isoPk
return power

def Gendk (N, kvec, Pk, b, f, deltak3d):
Nhalf = int(N/2)
for xx, yy, zz in itr.product(range(0,N), range(0,N), range(0,Nhalf+1)):
kx = kvec[xx]
ky = kvec[yy]
kz = kvec[zz]
kk = math.sqrt(kx**2+ky**2+kz**2)
if kk == 0:
continue
mu = kz/kk
power = Pk(b, f, mu, kk)
if power==0:
deltaRe = 0
deltaIm = 0
else:
deltaRe = np.random.normal(0, power/2.0)
if (xx==0 or xx==Nhalf) and (yy==0 or yy==Nhalf) and (zz==0 or zz==Nhalf):
deltaIm = 0
else:
deltaIm = np.random.normal(0, power/2.0)
x_conj = (2*N-xx)%N
y_conj = (2*N-yy)%N
z_conj = (2*N-zz)%N
deltak3d[xx,yy,zz] = deltaRe + deltaIm*1j
deltak3d[x_conj,y_conj,z_conj] = deltaRe - deltaIm*1j

Ntot = 300000
L = 1000
N = 128
Nhalf = int(N/2)
kmax = 5.0
dk = kmax/N
kvec = np.fft.fftfreq(N, L/N)
dL = L/N
deltak3d = np.zeros((N,N,N), dtype=complex)
deltak3d[0,0,0] = Ntot
Gendk(N, kvec, Pk, 2, 1, deltak3d)

这是 cython 的版本:

import numpy as np
import pyximport; pyximport.install(setup_args={"include_dirs":np.get_include()})
import testGauss as tG

Ntot = 300000
L = 1000
N = 128
Nhalf = int(N/2)
kmax = 5.0
dk = kmax/N
kvec = np.fft.fftfreq(N, L/N)
dL = L/N
deltak3d = np.zeros((N,N,N), dtype=complex)
deltak3d[0,0,0] = Ntot
tG.Gendk(N, kvec, tG.Pk, 2, 1, deltak3d)

testGauss.pyx 文件是:

import math
import numpy as np
cimport numpy as np
import itertools as itr

def Pk (double b, double f, double mu, double k): # k is in Mpc
cdef double isoPk, power
isoPk = 200*math.exp(-(k-0.02)**2/2/0.01**2) # Isotropic power spectrum
power = (b+mu**2*f)**2*isoPk
return power

def Gendk (int N, np.ndarray[np.float64_t,ndim=1] kvec, Pk, double b, double f, np.ndarray[np.complex128_t,ndim=3] deltak3d):
cdef int Nhalf = int(N/2)
cdef int xx, yy, zz
cdef int x_conj, y_conj, z_conj
cdef double kx, ky, kz, kk
cdef mu
cdef power
cdef deltaRe, deltaIm
for xx, yy, zz in itr.product(range(0,N), range(0,N), range(0,Nhalf+1)):
kx = kvec[xx]
ky = kvec[yy]
kz = kvec[zz]
kk = math.sqrt(kx**2+ky**2+kz**2)
if kk == 0:
continue
mu = kz/kk
power = Pk(b, f, mu, kk)
if power==0:
deltaRe = 0
deltaIm = 0
else:
deltaRe = np.random.normal(0, power/2.0)
if (xx==0 or xx==Nhalf) and (yy==0 or yy==Nhalf) and (zz==0 or zz==Nhalf):
deltaIm = 0
else:
deltaIm = np.random.normal(0, power/2.0)
x_conj = (2*N-xx)%N
y_conj = (2*N-yy)%N
z_conj = (2*N-zz)%N
deltak3d[xx,yy,zz] = deltaRe + deltaIm*1j
deltak3d[x_conj,y_conj,z_conj] = deltaRe - deltaIm*1j

非常感谢您!

最佳答案

你可以通过替换来获得一些加速

import math

from libc cimport math

这将避免在执行 sqrt 和 exp 时调用 python 函数,将其替换为直接的 c 调用(这应该快得多)。

我也有点担心循环内对 np.random.normal 的调用,它每次都会增加合理的 Python 开销。在循环之前调用它以生成大量随机数(具有单个 python 调用的开销)然后用 0 覆盖它们(如果循环内不需要它们)可能会更快。

优化 Cython 的一般建议仍然适用:运行

cython -a your_file.pyx

查看 HTML,并担心突出显示为黄色的位(但前提是它们经常被调用)

关于python - Cython 不够快,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/34123589/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com