gpt4 book ai didi

scipy - 有没有办法将 scipy.optimize.fsolve 与 jit_integrand_function 和 scipy.integrate.quad 一起使用?

转载 作者:行者123 更新时间:2023-12-05 07:04:28 27 4
gpt4 key购买 nike

基于此处提供的解释 1 ,我正在尝试使用相同的想法来加速以下积分:

import scipy.integrate as si
from scipy.optimize import root, fsolve
import numba
from numba import cfunc
from numba.types import intc, CPointer, float64
from scipy import LowLevelCallable

def integrand(t, *args):
a = args[0]
c = fsolve(lambda x: a * x**2 - np.exp(-x**2 / a), 1)[0]
return c * np.exp(- (t / (a * c))**2)

def do_integrate(func, a):
return si.quad(func, 0, 1, args=(a,))

print(do_integrate(integrand, 2.)[0])

根据前面的引用资料,我尝试使用numba/jit,并按照以下方式修改前面的 block :

import numpy as np
import scipy.integrate as si
from scipy.optimize import root
import numba
from numba import cfunc
from numba.types import intc, CPointer, float64
from scipy import LowLevelCallable

def jit_integrand_function(integrand_function):
jitted_function = numba.jit(integrand_function, nopython=True)
@cfunc(float64(intc, CPointer(float64)))
def wrapped(n, xx):
return jitted_function(xx[0], xx[1])
return LowLevelCallable(wrapped.ctypes)

@jit_integrand_function
def integrand(t, *args):
a = args[0]
c = fsolve(lambda x: a * x**2 - np.exp(-x**2 / a), 1)[0]
return c * np.exp(- (t / (a * c))**2)

def do_integrate(func, a):
return si.quad(func, 0, 1, args=(a,))

do_integrate(integrand, 2.)

但是,这个实现给了我错误


TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Failed in nopython mode pipeline (step: convert make_function into JIT functions)
Cannot capture the non-constant value associated with variable 'a' in a function that will escape.

File "<ipython-input-16-3d98286a4be7>", line 20:
def integrand(t, *args):
<source elided>
a = args[0]
c = fsolve(lambda x: a * x**2 - np.exp(-x**2 / a), 1)[0]
^

During: resolving callee type: type(CPUDispatcher(<function integrand at 0x11a949d08>))
During: typing of call at <ipython-input-16-3d98286a4be7> (14)

During: resolving callee type: type(CPUDispatcher(<function integrand at 0x11a949d08>))
During: typing of call at <ipython-input-16-3d98286a4be7> (14)

错误是因为我在被积函数中使用了 scipy.optimize 中的 fsolve。

我想知道是否有解决此错误的方法,以及是否可以在此上下文中将 scipy.optimize.fsolve 与 numba 一起使用。

最佳答案

我为 Minpack 编写了一个小的 python 包装器,称为 NumbaMinpack,它可以在 numba 编译函数中调用:https://github.com/Nicholaswogan/NumbaMinpack .你可以用它来 @njit 被积函数:

import scipy.integrate as si
from NumbaMinpack import hybrd, minpack_sig
from numba import njit, cfunc
import numpy as np

@cfunc(minpack_sig)
def f(x, fvec, args):
a = args[0]
fvec[0] = a * x[0]**2.0 - np.exp(-x[0]**2.0 / a)

funcptr = f.address # pointer to function

@njit
def integrand(t, *args):
a = args[0]
args_ = np.array(args)
x_init = np.array([1.0])
sol = hybrd(funcptr,x_init,args_)
c = sol[0][0]
return c * np.exp(- (t / (a * c))**2)

def do_integrate(func, a):
return si.quad(func, 0, 1, args=(a,))

print(do_integrate(integrand, 2.)[0])

在我的电脑上,上述代码耗时 87 µs,而纯 python 版本耗时 2920 µs

关于scipy - 有没有办法将 scipy.optimize.fsolve 与 jit_integrand_function 和 scipy.integrate.quad 一起使用?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/62924588/

27 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com