gpt4 book ai didi

python - 决策树未捕获因变量的方差

转载 作者:行者123 更新时间:2023-11-30 09:31:44 24 4
gpt4 key购买 nike

我正在使用决策树回归器。数据点数量为15000个,有15个特征。我面临的问题是,即使在高度过度拟合的条件下(我设置深度 = 25,叶子上的 min.samples = 2),预测的方差也比因变量低得多(即它仍然欠拟合)。首先,我认为这可能是偏差方差问题,但是所有预测的平均值和因变量的平均值等于小数点后 9 位。

即它看起来像: enter image description here

因此,预测和因变量的 View 如下: enter image description here

我能想到的一个原因是我选择的功能可能根本不重要。然而它们确实是有道理的。

有人可以解释一下这里可能出了什么问题吗?任何帮助将不胜感激。谢谢

最佳答案

抛开您自己的数据细节不谈,一旦您了解决策树在幕后实际执行的操作,这原则上就不足为奇了。

回归树实际返回的输出是训练样本的因变量y平均值,这些样本最终出现在各自的终端节点(叶子)中。实际上,这意味着默认情况下输出是离散化的:在输出处获得的值位于终端节点中的有限值集中,它们之间没有任何插值。

鉴于此,直观上预测的方差低于实际值并不令人惊讶,具体低多少取决于终端节点的数量(即 max_depth) ,当然还有数据本身。

以下情节来自 documentation应该有助于形象化这个想法 - 应该直观地清楚地看到数据的方差确实高于(离散的)预测之一:

enter image description here

让我们调整该示例中的代码,添加一些异常值(这会放大问题):

import numpy as np
from sklearn.tree import DecisionTreeRegressor

# dummy data
rng = np.random.RandomState(1)
X = np.sort(5 * rng.rand(80, 1), axis=0)
y = np.sin(X).ravel()
y[::5] += 3 * (0.5 - 5*rng.rand(16)) # modify here - 5*

estimator_1 = DecisionTreeRegressor(max_depth=2)
estimator_1.fit(X, y)

estimator_2 = DecisionTreeRegressor(max_depth=5)
estimator_2.fit(X, y)

y_pred_1 = estimator_1.predict(X)
y_pred_2 = estimator_2.predict(X)

现在让我们检查差异:

np.var(y) # true data
# 11.238416688700267

np.var(y_pred_1) # max_depth=2
# 1.7423865989859313

np.var(y_pred_2) # max_depth=5
# 6.1398871265574595

正如预期的那样,预测的方差随着树深度的增加而增加,但它仍然(显着)低于真实数据的方差。当然,所有的平均值都是相同的:

np.mean(y)
# -1.2561013675900665

np.mean(y_pred_1)
# -1.2561013675900665

np.mean(y_pred_2)
# -1.2561013675900665

所有这些对于新手来说可能看起来令人惊讶,特别是如果他们试图“天真地”扩展线性回归的线性思维;但是决策树存在于它们自己的领域中,这当然与线性树不同(而且相当远)。

回到我在答案中提到的离散化问题,让我们检查一下我们的预测得到了多少个唯一值;为了简单起见,仅将讨论保留在 y_pred_1 中:

np.unique(y_pred_1)
# array([-11.74901949, -1.9966201 , -0.71895532])

就是这样;从回归树中获得的每个输出都将是这三个值之一,并且绝不会任何“介于”之间的值,例如-10-5.82 或 [...](即无插值)。现在,至少从直觉上讲,您应该能够说服自己,这种情况下的方差不出所料地(远...)低于实际数据(默认情况下,预测的分散程度较小)...

关于python - 决策树未捕获因变量的方差,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/55042015/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com