gpt4 book ai didi

python - 分段回归 Python

转载 作者:行者123 更新时间:2023-11-28 16:27:23 44 4
gpt4 key购买 nike

您好,我正在尝试弄清楚如何使用分段线性函数来拟合这些值。我已经阅读了这个问题,但我无法继续( How to apply piecewise linear fit in Python? )。在这个例子中展示了如何为 2 段案例实现分段函数。但是我需要在如图所示的三段情况下进行。 Three segment data

我已经写了这段代码:

from scipy import optimize
import matplotlib.pyplot as plt
import numpy as np


x1 = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10 ,11, 12, 13, 14, 15,16,17,18,19,20,21], dtype=float)
y1 = np.array([5, 7, 9, 11, 13, 15, 28.92, 42.81, 56.7, 70.59, 84.47, 98.36, 112.25, 126.14, 140.03,145,147,149,151,153,155])



x = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10 ,11, 12, 13, 14, 15], dtype=float)
y = np.array([5, 7, 9, 11, 13, 15, 28.92, 42.81, 56.7, 70.59, 84.47, 98.36, 112.25, 126.14, 140.03])

def piecewise(x,x0,x1,y0,y1,k0,k1,k2):
return np.piecewise(x , [x <= x0, (x>= x1)] , [lambda x:k0*x + y0-k0*x0, lambda x:k1*(x-(x1+x0))-y1, lambda x:k2*x + y1-k2*x1])

p , e = optimize.curve_fit(piecewise_linear, x1, y1)
xd = np.linspace(0, 15, 100)
plt.figure()
plt.plot(x1, y1, "o")
plt.plot(xd, piecewise_linear(xd, *p))

但这是输出

enter image description here

有什么建议吗?我相信问题出在 return np.piecewise(x , [x <= x0, (x>= x1)] , [lambda x:k0*x + y0-k0*x0, lambda x:k1*(x-(x1+x0))-y1, lambda x:k2*x + y1-k2*x1]) 中,特别是在第二个 lambda 中。

编辑 1:

如果我尝试使用 A.L. 提供的解决方案的不同数据,我不会得到好的结果。

enter image description here

我得到这个结果:

enter image description here

x=[ 16.01690476,  16.13801587,  14.63628571,  15.32664399,
15.8145 , 15.71507143, 15.56107143, 15.553 ,
15.08734524, 14.97275 , 15.51958333, 16.61981859,
16.36589286, 14.78708333, 14.41565476, 13.47763158,
13.42412281, 12.95551378, 13.66601504, 13.63315789,
13.21463659, 13.53464286, 14.60130952, 14.7774881 ,
13.04319048, 12.53385965, 12.65745614, 13.90535714,
14.82412281, 14.6565 , 15.09541667, 13.41434524,
13.66033333, 14.57964286, 13.55416667, 13.43041667,
13.01137566, 12.76429825, 11.55241667, 11.0634881 ,
10.92729762, 11.21625 , 10.72092857, 11.80380952,
12.55233333, 12.11307143, 11.78892857, 12.45458333,
11.05539286, 10.69214286, 10.32566667, 11.3439881 ,
9.69563492, 10.72535714, 10.26180272, 7.77272727,
6.37704082, 8.49666667, 8.5389881 , 5.68547619,
7.00616667, 8.22015873, 10.20315476, 15.35736842,
12.25158333, 11.09622153, 10.4118254 , 9.8602381 ,
10.16727273, 15.10858333, 13.82215539, 12.44719298,
10.92341667, 11.44565476, 11.43333333, 10.5045 ,
11.14357143, 10.37625 , 8.93421769, 9.48444444,
10.43483333, 10.8659881 , 10.96166667, 10.12872619,
9.64663265, 9.29979762, 9.67173469, 8.978322 ,
9.10419501, 9.45411565, 10.46411565, 7.95739229,
8.72616667, 7.03892857, 7.32547619, 7.56441667,
6.61022676, 9.09014739, 10.78141667, 10.85918367,
11.11665476, 10.141 , 9.17760771, 8.27968254,
11.02625 , 12.34809524, 11.17807018, 11.25416667,
11.29236905, 9.28357143, 9.77033333, 11.52086168,
9.8625 , 12.60281955, 12.42785714, 12.11902256,
13.1 , 13.02791667, 13.87779449, 15.09857143,
13.93935185, 13.69821429, 13.39880952, 12.45692982,
12.76921053, 13.23708333, 13.71666667, 15.39807143,
15.27916667, 14.66464286, 13.38694444, 10.97555556,
10.02191667, 11.99608333, 14.26325 , 15.40991667,
15.12908333, 15.76265476, 12.12763158, 15.01641667,
14.39602381, 12.98532143, 14.98807018, 18.30547619,
16.7564966 , 16.82982143, 19.8487013 , 19.18600907]

y=[ 2.36846863,  2.73722628,  2.77177583,  2.63930636,  2.80864749,
2.57066667, 2.65277287, 2.57162347, 2.76295667, 2.79835391,
2.60431154, 2.17326401, 2.67740698, 2.47138153, 2.49882574,
2.60987338, 2.69935565, 2.60755362, 2.77702029, 2.62996942,
2.45959517, 2.52750434, 2.73833005, 2.52009 , 2.80933226,
1.63807085, 2.49230099, 2.55441614, 3.19256506, 2.52609288,
1.02931596, 2.40266963, 2.3306463 , 2.69094276, 2.60779985,
2.48351648, 2.45131766, 2.40526763, 2.03952569, 1.86217009,
1.79971848, 1.91772218, 1.85895421, 2.32725731, 2.28189713,
2.11835833, 2.09636517, 2.2230303 , 1.85863317, 1.77550406,
1.68862391, 1.79187765, 1.70887476, 1.81911193, 1.74802483,
1.65776432, 1.58012849, 1.67781494, 1.62451541, 1.60555884,
1.56172214, 1.60083809, 1.65256994, 2.74794704, 2.27089627,
1.80364982, 1.51412482, 1.77738757, 1.56979564, 2.46538633,
2.37679625, 2.40389294, 2.04165763, 1.82086407, 1.90609219,
1.87480978, 1.8877854 , 1.76080074, 1.68369028, 1.57419297,
1.66470126, 1.74522552, 1.72459756, 1.65510503, 1.72131148,
1.6254417 , 1.57091907, 1.68755268, 1.70307911, 1.59445121,
1.74393783, 1.72913779, 1.66883237, 1.59859545, 1.62335831,
1.73378184, 1.62621588, 1.79532164, 1.78289992, 1.79475101,
1.7826266 , 1.68778918, 1.64484127, 1.62332696, 1.75372393,
1.99038021, 1.87268137, 1.86124502, 1.82435911, 1.62927102,
1.66443723, 1.86743516, 1.62745098, 2.20200312, 2.09641026,
2.26649111, 2.63271605, 2.18050721, 2.57138433, 2.51833359,
2.74684184, 2.57209998, 2.63762019, 2.30027877, 2.28471286,
2.40323668, 2.37103313, 2.16414489, 1.01027109, 2.64181007,
2.45467765, 2.05773672, 1.73624917, 2.05233688, 2.70820669,
2.65594222, 2.67445635, 2.37212985, 2.48221803, 2.77655216,
2.62839879, 2.26481307, 2.58005799, 2.1188172 , 2.14017268,
2.16459571, 1.95083406, 1.46224418]

最佳答案

拟合分段线性函数是一个可能具有局部最优值的非线性优化问题。您看到的结果可能是您的优化算法卡住的局部最优解之一。

解决此问题的一种方法是使用不同的初始值重复您的优化算法并采用最佳拟合。我使用平均绝对误差 (MAE) 来比较不同的拟合。

perr = np.sum(np.abs(y1-piecewise(x1, *p)))

我还更改了您的分段函数,因为它让我有点困惑。但还是和之前一样是分段函数

进一步认为您忘记将 x 和 xd 数组扩展到值 21。(这就是绿线提前结束的原因)。

from scipy import optimize
import matplotlib.pyplot as plt
import numpy as np


def piecewise(x,x0,x1,y0,y1,k0,k1,k2):
return np.piecewise(x , [x <= x0, np.logical_and(x0<x, x<= x1),x>x1] , [lambda x:k0*x + y0, lambda x:k1*(x-x0)+y1+k0*x0,
lambda x:k2*(x-x1) + y0+y1+k0*x0+k1*(x1-x0)])

x1 = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10 ,11, 12, 13, 14, 15,16,17,18,19,20,21], dtype=float)
y1 = np.array([5, 7, 9, 11, 13, 15, 28.92, 42.81, 56.7, 70.59, 84.47, 98.36, 112.25, 126.14, 140.03,145,147,149,151,153,155])
x = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10 ,11, 12, 13, 14, 15,16,17,18,19,20,21], dtype=float)
y = np.array([5, 7, 9, 11, 13, 15, 28.92, 42.81, 56.7, 70.59, 84.47, 98.36, 112.25, 126.14, 140.03,145,147,149,151,153,155])


perr_min = np.inf
p_best = None
for n in range(100):
k = np.random.rand(7)*20
p , e = optimize.curve_fit(piecewise, x1, y1,p0=k)
perr = np.sum(np.abs(y1-piecewise(x1, *p)))
if(perr < perr_min):
perr_min = perr
p_best = p

xd = np.linspace(0, 21, 100)
plt.figure()
plt.plot(x1, y1, "o")
y_out = piecewise(xd, *p_best)
plt.plot(xd, y_out)
plt.show()

这给了我: enter image description here

p = [ 6.34259491 15.00000023 2.97272604 7.05498314 2.00751828 13.88881542 1.99960597]

编辑1

您编辑了您的问题,这是编辑后问题的答案。抱歉,我是 stackoverlfow 的新手,不确定是否应该发布另一个答案

在您的第二个数据集中,您向数据添加了噪音。在我看来,有两种噪音。高斯分布,将点置于靠近底层分段线的位置,离群噪声将点置于远离原始底层线的位置。

在引擎盖下,您使用的优化算法根据 p 优化以下内容:E = sum(square(y-piecewise(x,p))) http://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.curve_fit.html#scipy.optimize.curve_fit

高斯噪声问题不大。您使用的优化间接假设了这种高斯噪声(通过最小化最小二乘误差)并尽可能好地拟合直线。真正的问题来自异常值。

问题是异常值与原始函数相去甚远。即使优化尝试最佳参数,能量函数 E 也不会最小,因为您的离群值远离原始函数并且该距离是偶数平方的,因此它使函数 E 的最小值远离真实参数你的功能。

那么解决方案是什么?摆脱异常值。

一种自动化的方法是 ransac https://en.wikipedia.org/wiki/RANSAC .

简而言之:您选择原始数据的随机子集。您希望子集没有异常值。您将函数拟合到子集并丢弃远离拟合函数的点。如果此步骤中有足够的点幸存下来,您将获取所有幸存的点并重复拟合。此“内点”集的误差是衡量拟合质量的指标。然后你重复整个过程并选择最佳的最终拟合。

我相应地调整了我的脚本:

from scipy import optimize
import matplotlib.pyplot as plt
import numpy as np
def piecewise(x,x0,x1,y0,y1,k0,k1,k2):
return np.piecewise(x , [x <= x0, np.logical_and(x0<x, x<= x1),x>x1] , [lambda x:k0*x + y0, lambda x:k1*(x-x0)+y1+k0*x0,
lambda x:k2*(x-x1) + y0+y1+k0*x0+k1*(x1-x0)])

x = np.array(x)
y = np.array(y)

x1 = x
y1 = y



perr_min = np.inf
p_best = None
for n in range(100):
idx = np.random.choice(np.arange(len(x)), 10, replace=False)
x_sample = x[idx]
y_sample = y[idx]
k = np.random.rand(7)*20
try:
p , e = optimize.curve_fit(piecewise, x_sample,y_sample ,p0=k)
each_error = np.abs(y-piecewise(x, *p))
x_inliner = x[each_error < 1]
y_inlier = y[each_error < 1]
if(x_inliner.shape[0] < 0.8 * x.shape[0]):
continue

p_inlier , e_inlier = optimize.curve_fit(piecewise, x_inliner,y_inlier ,p0=p)
perr = np.sum(np.abs(y-piecewise(x, *p_inlier)))


if(perr < perr_min):
perr_min = perr
p_best = p_inlier
except RuntimeError:
pass

xd = np.linspace(0, 21, 100)
plt.figure()
plt.plot(x, y, "o")
y_out = piecewise(xd, *p_best)
plt.plot(xd, y_out)
print p_best
plt.show()

重复 100 次后,我得到以下结果: Fittting the curve with Ransac and least squares

关于python - 分段回归 Python,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/35415372/

44 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com