--- 使用 python 3 ---
根据方程 here ,我试图找到任意三阶多项式的所有实根。不幸的是,我的实现没有产生正确的结果,我也找不到错误。也许你能在眨眼间发现它并告诉我。
(如您所见,只有绿色曲线的根是错误的。)
致以最诚挚的问候
import numpy as np
def find_cubic_roots(a,b,c,d):
# with ax³ + bx² + cx + d = 0
a,b,c,d = a+0j, b+0j, c+0j, d+0j
all_ = (a != np.pi)
Q = (3*a*c - b**2)/ (9*a**2)
R = (9*a*b*c - 27*a**2*d - 2*b**3) / (54 * a**3)
D = Q**3 + R**2
S = (R + np.sqrt(D))**(1/3)
T = (R - np.sqrt(D))**(1/3)
result = np.zeros(tuple(list(a.shape) + [3])) + 0j
result[all_,0] = - b / (3*a) + (S+T)
result[all_,1] = - b / (3*a) - (S+T) / 2 + 0.5j * np.sqrt(3) * (S - T)
result[all_,2] = - b / (3*a) - (S+T) / 2 - 0.5j * np.sqrt(3) * (S - T)
return result
您看到的示例不起作用:
import matplotlib.pyplot as plt
fig, ax = plt.subplots()
a = np.array([2.5])
b = np.array([-5])
c = np.array([0])
x = np.linspace(-2,3,100)
for i, d in enumerate([-8,0,8]):
d = np.array(d)
roots = find_cubic_roots(a,b,c,d)
ax.plot(x, a*x**3 + b*x**2 + c*x + d, label = "a = %.3f, b = %.3f, c = %.3f, d = %.3f"%(a,b,c,d), color = colors[i])
print(roots)
ax.plot(x, x*0)
ax.scatter(roots,roots*0, s = 80)
ax.legend(loc = 0)
ax.set_xlim(-2,3)
plt.show()
输出:
[[ 2.50852567+0.j -0.25426283+1.1004545j -0.25426283-1.1004545j]]
[[ 2.+0.j 0.+0.j 0.-0.j]]
[[ 1.51400399+1.46763129j 1.02750817-1.1867528j -0.54151216-0.28087849j]]
这是我的解决方案。如果 R + np.sqrt(D)
或 R - np.sqrt(D)
为负,您的代码将失败。原因在this post .基本上如果你执行 a**(1/3)
其中 a
是负数,numpy 返回一个复数。然而,事实上,我们希望 S
和 T
是实数,因为负实数的立方根只是负实数(让我们忽略 De Moivre's theorem 现在专注于代码而不是数学)。解决它的方法是检查 S
是否真实,将其转换为真实并将 S
传递给函数 from scipy.special import cbrt
。 T
也类似。示例代码:
import numpy as np
import pdb
import math
from scipy.special import cbrt
def find_cubic_roots(a,b,c,d, bp = False):
a,b,c,d = a+0j, b+0j, c+0j, d+0j
all_ = (a != np.pi)
Q = (3*a*c - b**2)/ (9*a**2)
R = (9*a*b*c - 27*a**2*d - 2*b**3) / (54 * a**3)
D = Q**3 + R**2
S = 0 #NEW CALCULATION FOR S STARTS HERE
if np.isreal(R + np.sqrt(D)):
S = cbrt(np.real(R + np.sqrt(D)))
else:
S = (R + np.sqrt(D))**(1/3)
T = 0 #NEW CALCULATION FOR T STARTS HERE
if np.isreal(R - np.sqrt(D)):
T = cbrt(np.real(R - np.sqrt(D)))
else:
T = (R - np.sqrt(D))**(1/3)
result = np.zeros(tuple(list(a.shape) + [3])) + 0j
result[all_,0] = - b / (3*a) + (S+T)
result[all_,1] = - b / (3*a) - (S+T) / 2 + 0.5j * np.sqrt(3) * (S - T)
result[all_,2] = - b / (3*a) - (S+T) / 2 - 0.5j * np.sqrt(3) * (S - T)
#if bp:
#pdb.set_trace()
return result
import matplotlib.pyplot as plt
fig, ax = plt.subplots()
a = np.array([2.5])
b = np.array([-5])
c = np.array([0])
x = np.linspace(-2,3,100)
for i, d in enumerate([-8,0,8]):
d = np.array(d)
if d == 8:
roots = find_cubic_roots(a,b,c,d, True)
else:
roots = find_cubic_roots(a,b,c,d)
ax.plot(x, a*x**3 + b*x**2 + c*x + d, label = "a = %.3f, b = %.3f, c = %.3f, d = %.3f"%(a,b,c,d))
print(roots)
ax.plot(x, x*0)
ax.scatter(roots,roots*0, s = 80)
ax.legend(loc = 0)
ax.set_xlim(-2,3)
plt.show()
免责声明:输出根给出了一些警告,您可以可能忽略这些警告。输出是正确的。但是,由于某些原因,绘图显示了一个额外的根。这可能是由于您的绘图代码。不过打印的根部看起来不错。
我是一名优秀的程序员,十分优秀!