gpt4 book ai didi

python - 如何获取scikit-learn的DecisionTreeRegressor中节点的MSE?

转载 作者:行者123 更新时间:2023-12-01 16:39:39 24 4
gpt4 key购买 nike

生成的决策树回归模型中,使用graphviz查看树结构时有一个MSE属性。我需要获取每个叶子节点的MSE,并根据MSE进行后续操作。但是,读完文档后,我找不到提供输出MSE的方法。其他属性如特征名称、样本数、预测值等都有对应的方法:

Tree structure

使用 help(sklearn.tree._tree.Tree),我可以看到大多数属性都有一些输出值的方法,但我没有看到任何有关 MSE 的信息。

有关模块 sklearn.tree._tree 中的 Tree 类的帮助 Help on class Tree in module sklearn.tree._tree

最佳答案

问得好。您需要tree_reg.tree_.impurity

简短回答:

tree_reg = tree.DecisionTreeRegressor(max_depth=2)
tree_reg.fit(X_train, y_train)

extracted_MSEs = tree_reg.tree_.impurity # The Hidden magic is HERE

for idx, MSE in enumerate(tree_reg.tree_.impurity):
print("Node {} has MSE {}".format(idx,MSE))

Node 0 has MSE 86.873403833
Node 1 has MSE 40.3211827171
Node 2 has MSE 25.6934820064
Node 3 has MSE 19.0053469592
Node 4 has MSE 74.6839429717
Node 5 has MSE 38.3057346817
Node 6 has MSE 39.6709615385

使用带有视觉输出的 boston 数据集的长答案:

import pandas as pd
import numpy as np
from sklearn import ensemble, model_selection, metrics, datasets, tree
import graphviz

house_prices = datasets.load_boston()

X_train, X_test, y_train, y_test = model_selection.train_test_split(
pd.DataFrame(house_prices.data, columns=house_prices.feature_names),
pd.Series(house_prices.target, name="med_price"),
test_size=0.20, random_state=42)

tree_reg = tree.DecisionTreeRegressor(max_depth=2)
tree_reg.fit(X_train, y_train)

extracted_MSEs = tree_reg.tree_.impurity # YOU NEED THIS
print(extracted_MSEs)
#[86.87340383 40.32118272 25.69348201 19.00534696 74.68394297 38.30573468 39.67096154]

# Compare visually
dot_data = tree.export_graphviz(tree_reg, out_file=None, feature_names=X_train.columns)
graph = graphviz.Source(dot_data)

#this will create an boston.pdf file with the rule path
graph.render("boston")

将 MSE 值与视觉输出进行比较:

enter image description here

关于python - 如何获取scikit-learn的DecisionTreeRegressor中节点的MSE?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/59375220/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com