gpt4 book ai didi

python - 在python中找到四次多项式4次最小正实根的最快方法

转载 作者:太空狗 更新时间:2023-10-30 01:00:20 24 4
gpt4 key购买 nike

[我要的]是求四次函数ax^4 + bx^唯一最小的正实根3 + cx^2 + dx + e

[现有方法]我的方程用于碰撞预测,最大次数是四次函数 f(x) = ax^4 + bx^3 + cx^2 + dx + ea,b,c,d,e 系数可以是正/负/零(实际浮点值)。所以我的函数 f(x) 可以是四次函数、三次函数或二次函数,具体取决于 a、b、c、d 和 e 输入系数。

目前,我使用 NumPy 来查找根,如下所示。

import numpy

root_output = numpy.roots([a, b, c ,d ,e])

NumPy 模块中的“root_output”可以是所有可能的实根/复根,具体取决于输入系数。所以我必须一个一个地看“root_output”,检查哪个根是最小的实数正值(root>0?)

[问题]我的程序需要多次执行 numpy.roots([a, b, c, d, e]) ,多次执行 numpy.roots 对我的项目来说太慢了。并且 (a, b, c ,d ,e) 值总是在每次执行 numpy.roots 时改变

我的尝试是在 Raspberry Pi2 上运行代码。以下是处理时间的示例。

  • 在 PC 上多次运行 numpy.roots:1.2 秒
  • 在 Raspberry Pi2 上多次运行 numpy.roots:17 秒

能否请您指导我如何在最快的解决方案中找到最小的正实根?使用 scipy.optimize 或实现一些算法来加速查找根或您的任何建议都会很棒。

谢谢。

[解决方案]

  • 二次函数只需要正实根(请注意被零除)
def SolvQuadratic(a, b ,c):
d = (b**2) - (4*a*c)
if d < 0:
return []

if d > 0:
square_root_d = math.sqrt(d)
t1 = (-b + square_root_d) / (2 * a)
t2 = (-b - square_root_d) / (2 * a)
if t1 > 0:
if t2 > 0:
if t1 < t2:
return [t1, t2]
return [t2, t1]
return [t1]
elif t2 > 0:
return [t2]
else:
return []
else:
t = -b / (2*a)
if t > 0:
return [t]
return []
  • Quartic Function 对于四次函数,您可以使用纯 python/numba 版本作为@B.M. 的以下答案。我还从@B.M 的代码中添加了另一个 cython 版本。您可以将以下代码用作 .pyx 文件,然后编译它以获得比纯 python 快约 2 倍的速度(请注意舍入问题)。
import cmath

cdef extern from "complex.h":
double complex cexp(double complex)

cdef double complex J=cexp(2j*cmath.pi/3)
cdef double complex Jc=1/J

cdef Cardano(double a, double b, double c, double d):
cdef double z0
cdef double a2, b2
cdef double p ,q, D
cdef double complex r
cdef double complex u, v, w
cdef double w0, w1, w2
cdef double complex r1, r2, r3


z0=b/3/a
a2,b2 = a*a,b*b
p=-b2/3/a2 +c/a
q=(b/27*(2*b2/a2-9*c/a)+d)/a
D=-4*p*p*p-27*q*q
r=cmath.sqrt(-D/27+0j)
u=((-q-r)/2)**0.33333333333333333333333
v=((-q+r)/2)**0.33333333333333333333333
w=u*v
w0=abs(w+p/3)
w1=abs(w*J+p/3)
w2=abs(w*Jc+p/3)
if w0<w1:
if w2<w0 : v = v*Jc
elif w2<w1 : v = v*Jc
else: v = v*J
r1 = u+v-z0
r2 = u*J+v*Jc-z0
r3 = u*Jc+v*J-z0
return r1, r2, r3

cdef Roots_2(double a, double complex b, double complex c):
cdef double complex bp
cdef double complex delta
cdef double complex r1, r2


bp=b/2
delta=bp*bp-a*c
r1=(-bp-delta**.5)/a
r2=-r1-b/a
return r1, r2

def SolveQuartic(double a, double b, double c, double d, double e):
"Ferrarai's Method"
"resolution of P=ax^4+bx^3+cx^2+dx+e=0, coeffs reals"
"First shift : x= z-b/4/a => P=z^4+pz^2+qz+r"
cdef double z0
cdef double a2, b2, c2, d2
cdef double p, q, r
cdef double A, B, C, D
cdef double complex y0, y1, y2
cdef double complex a0, b0
cdef double complex r0, r1, r2, r3


z0=b/4.0/a
a2,b2,c2,d2 = a*a,b*b,c*c,d*d
p = -3.0*b2/(8*a2)+c/a
q = b*b2/8.0/a/a2 - 1.0/2*b*c/a2 + d/a
r = -3.0/256*b2*b2/a2/a2 + c*b2/a2/a/16 - b*d/a2/4+e/a
"Second find y so P2=Ay^3+By^2+Cy+D=0"
A=8.0
B=-4*p
C=-8*r
D=4*r*p-q*q
y0,y1,y2=Cardano(A,B,C,D)
if abs(y1.imag)<abs(y0.imag): y0=y1
if abs(y2.imag)<abs(y0.imag): y0=y2
a0=(-p+2*y0)**.5
if a0==0 : b0=y0**2-r
else : b0=-q/2/a0
r0,r1=Roots_2(1,a0,y0+b0)
r2,r3=Roots_2(1,-a0,y0-b0)
return (r0-z0,r1-z0,r2-z0,r3-z0)

[Ferrari 方法的问题] 当四次方程的系数为 [0.00614656, -0.0933333333333, 0.527664995846, -1.31617928376, 1.21906444869] 来自 numpy.roots 和 ferrari 的输出时,我们面临着这个问题方法完全不同(numpy.roots 是正确的输出)。

import numpy as np
import cmath


J=cmath.exp(2j*cmath.pi/3)
Jc=1/J

def ferrari(a,b,c,d,e):
"Ferrarai's Method"
"resolution of P=ax^4+bx^3+cx^2+dx+e=0, coeffs reals"
"First shift : x= z-b/4/a => P=z^4+pz^2+qz+r"
z0=b/4/a
a2,b2,c2,d2 = a*a,b*b,c*c,d*d
p = -3*b2/(8*a2)+c/a
q = b*b2/8/a/a2 - 1/2*b*c/a2 + d/a
r = -3/256*b2*b2/a2/a2 +c*b2/a2/a/16-b*d/a2/4+e/a
"Second find y so P2=Ay^3+By^2+Cy+D=0"
A=8
B=-4*p
C=-8*r
D=4*r*p-q*q
y0,y1,y2=Cardano(A,B,C,D)
if abs(y1.imag)<abs(y0.imag): y0=y1
if abs(y2.imag)<abs(y0.imag): y0=y2
a0=(-p+2*y0)**.5
if a0==0 : b0=y0**2-r
else : b0=-q/2/a0
r0,r1=Roots_2(1,a0,y0+b0)
r2,r3=Roots_2(1,-a0,y0-b0)
return (r0-z0,r1-z0,r2-z0,r3-z0)

#~ @jit(nopython=True)
def Cardano(a,b,c,d):
z0=b/3/a
a2,b2 = a*a,b*b
p=-b2/3/a2 +c/a
q=(b/27*(2*b2/a2-9*c/a)+d)/a
D=-4*p*p*p-27*q*q
r=cmath.sqrt(-D/27+0j)
u=((-q-r)/2)**0.33333333333333333333333
v=((-q+r)/2)**0.33333333333333333333333
w=u*v
w0=abs(w+p/3)
w1=abs(w*J+p/3)
w2=abs(w*Jc+p/3)
if w0<w1:
if w2<w0 : v*=Jc
elif w2<w1 : v*=Jc
else: v*=J
return u+v-z0, u*J+v*Jc-z0, u*Jc+v*J-z0

#~ @jit(nopython=True)
def Roots_2(a,b,c):
bp=b/2
delta=bp*bp-a*c
r1=(-bp-delta**.5)/a
r2=-r1-b/a
return r1,r2

coef = [0.00614656, -0.0933333333333, 0.527664995846, -1.31617928376, 1.21906444869]
print("Coefficient A, B, C, D, E", coef)
print("")
print("numpy roots: ", np.roots(coef))
print("")
print("ferrari python ", ferrari(*coef))

最佳答案

另一个答案:

使用分析方法(FerrariCardan),并使用即时编译(Numba)加速代码:

让我们先看看改进:

In [2]: P=poly1d([1,2,3,4],True)

In [3]: roots(P)
Out[3]: array([ 4., 3., 2., 1.])

In [4]: %timeit roots(P)
1000 loops, best of 3: 465 µs per loop

In [5]: ferrari(*P.coeffs)
Out[5]: ((1+0j), (2-0j), (3+0j), (4-0j))

In [5]: %timeit ferrari(*P.coeffs) #pure python without jit
10000 loops, best of 3: 116 µs per loop
In [6]: %timeit ferrari(*P.coeffs) # with numba.jit
100000 loops, best of 3: 13 µs per loop

然后是丑陋的代码:

对于订单 4:

@jit(nopython=True)
def ferrari(a,b,c,d,e):
"resolution of P=ax^4+bx^3+cx^2+dx+e=0"
"CN all coeffs real."
"First shift : x= z-b/4/a => P=z^4+pz^2+qz+r"
z0=b/4/a
a2,b2,c2,d2 = a*a,b*b,c*c,d*d
p = -3*b2/(8*a2)+c/a
q = b*b2/8/a/a2 - 1/2*b*c/a2 + d/a
r = -3/256*b2*b2/a2/a2 +c*b2/a2/a/16-b*d/a2/4+e/a
"Second find X so P2=AX^3+BX^2+C^X+D=0"
A=8
B=-4*p
C=-8*r
D=4*r*p-q*q
y0,y1,y2=cardan(A,B,C,D)
if abs(y1.imag)<abs(y0.imag): y0=y1
if abs(y2.imag)<abs(y0.imag): y0=y2
a0=(-p+2*y0.real)**.5
if a0==0 : b0=y0**2-r
else : b0=-q/2/a0
r0,r1=roots2(1,a0,y0+b0)
r2,r3=roots2(1,-a0,y0-b0)
return (r0-z0,r1-z0,r2-z0,r3-z0)

对于订单 3:

J=exp(2j*pi/3)
Jc=1/J

@jit(nopython=True)
def cardan(a,b,c,d):
u=empty(2,complex128)
z0=b/3/a
a2,b2 = a*a,b*b
p=-b2/3/a2 +c/a
q=(b/27*(2*b2/a2-9*c/a)+d)/a
D=-4*p*p*p-27*q*q
r=sqrt(-D/27+0j)
u=((-q-r)/2)**0.33333333333333333333333
v=((-q+r)/2)**0.33333333333333333333333
w=u*v
w0=abs(w+p/3)
w1=abs(w*J+p/3)
w2=abs(w*Jc+p/3)
if w0<w1:
if w2<w0 : v*=Jc
elif w2<w1 : v*=Jc
else: v*=J
return u+v-z0, u*J+v*Jc-z0,u*Jc+v*J-z0

对于订单 2:

@jit(nopython=True)
def roots2(a,b,c):
bp=b/2
delta=bp*bp-a*c
u1=(-bp-delta**.5)/a
u2=-u1-b/a
return u1,u2

可能需要进一步测试,但效率很高。

关于python - 在python中找到四次多项式4次最小正实根的最快方法,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/35795663/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com