
python - Pruning decision trees

Reposted. Author: 太空狗. Updated: 2023-10-29 17:05:12

Below is a snippet of the decision tree, since the full tree is very large.

enter image description here

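How can I make the tree stop growing when the lowest class count in a node falls below 5? The code that generates the decision tree is shown below. According to the scikit-learn decision tree documentation, the only way to do this seems to be min_impurity_decrease, but I am not sure exactly how it works.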

import numpy as np
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import DecisionTreeClassifier


X, y = make_classification(n_samples=1000,
                           n_features=6,
                           n_informative=3,
                           n_classes=2,
                           random_state=0,
                           shuffle=False)

# Creating a dataFrame
df = pd.DataFrame({'Feature 1': X[:, 0],
                   'Feature 2': X[:, 1],
                   'Feature 3': X[:, 2],
                   'Feature 4': X[:, 3],
                   'Feature 5': X[:, 4],
                   'Feature 6': X[:, 5],
                   'Class': y})


y_train = df['Class']
X_train = df.drop('Class',axis = 1)

dt = DecisionTreeClassifier( random_state=42)
dt.fit(X_train, y_train)

from IPython.display import display, Image
import pydotplus
from sklearn import tree
import os

os.environ["PATH"] += os.pathsep + 'C:\\Anaconda3\\Library\\bin\\graphviz'

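# With out_file set, export_graphviz writes the .dot file and returns None,
# so there is nothing useful to assign here.
tree.export_graphviz(dt, out_file='thisIsTheImagetree.dot',
                     feature_names=list(X_train.columns),
                     filled=True, rounded=True,
                     special_characters=True)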

graph = pydotplus.graph_from_dot_file('thisIsTheImagetree.dot')

thisIsTheImage = Image(graph.create_png())
display(thisIsTheImage)
#print(dt.tree_.feature)

from subprocess import check_call
check_call(['dot','-Tpng','thisIsTheImagetree.dot','-o','thisIsTheImagetree.png'])

Update

I think min_impurity_decrease can help reach the goal to some extent, since tweaking min_impurity_decrease does actually prune the tree. Can anyone explain min_impurity_decrease?

I am trying to understand the equation in the scikit-learn documentation, but I am not sure what the values of right_impurity and left_impurity are.

N = 256
N_t = 256
impurity = ??
N_t_R = 242
N_t_L = 14
right_impurity = ??
left_impurity = ??

New_Value = N_t / N * (impurity - ((N_t_R / N_t) * right_impurity)
- ((N_t_L / N_t) * left_impurity))
New_Value
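As a minimal sketch of how the unknowns in this equation could be filled in: DecisionTreeClassifier uses the Gini criterion by default, so impurity, left_impurity and right_impurity are the Gini impurities of the node and its two children, computed from their per-class sample counts. The question does not give those counts, so the ones below are hypothetical, chosen only to match N_t_L = 14 and N_t_R = 242. The helper names gini and weighted_impurity_decrease are illustrative and not part of scikit-learn.

```python
def gini(counts):
    """Gini impurity of a node given its per-class sample counts."""
    n = sum(counts)
    return 1.0 - sum((c / n) ** 2 for c in counts)


def weighted_impurity_decrease(n_total, parent, left, right):
    """Quantity that is compared against min_impurity_decrease.

    n_total: number of samples in the whole training set (N)
    parent, left, right: per-class counts of the node and its two children
    """
    n_t, n_t_l, n_t_r = sum(parent), sum(left), sum(right)
    return n_t / n_total * (
        gini(parent)
        - n_t_r / n_t * gini(right)
        - n_t_l / n_t * gini(left)
    )


# Hypothetical class counts (assumption): 14 samples go left, 242 go right
parent = [128, 128]
left = [4, 10]
right = [124, 118]

decrease = weighted_impurity_decrease(256, parent, left, right)
print("impurity          =", gini(parent))
print("left_impurity     =", gini(left))
print("right_impurity    =", gini(right))
print("weighted decrease =", decrease)
```

A split is kept only if this weighted decrease is greater than or equal to min_impurity_decrease, so the parameter bounds how much impurity a split must remove, not how many samples of a class may remain in a node.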

Update 2

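Instead of pruning at a fixed value, we prune under a certain condition. For example, we do split at 6/4 and 5/5, but not at 6000/4 or 5000/5. In other words, the test is whether one class count is below a certain percentage of the other class count in the same node, rather than below a fixed value.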

         11/9
       /      \
    6/4        5/5
   /   \      /   \
6/0    0/4  2/2   3/3
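The relative condition described in Update 2 can be sketched as a small predicate on a node's class counts. This is only an illustration: the function name is_too_lopsided and the 5% threshold are assumptions, not something the question specifies.

```python
def is_too_lopsided(class_counts, min_ratio):
    """Return True when the rarest class is below min_ratio of the most common one."""
    largest = max(class_counts)
    if largest == 0:
        return True  # empty node: nothing left to split
    return min(class_counts) / largest < min_ratio


MIN_RATIO = 0.05  # hypothetical threshold (assumption)

for counts in ([6, 4], [5, 5], [6000, 4], [5000, 5]):
    action = "stop" if is_too_lopsided(counts, MIN_RATIO) else "split"
    print(counts, "->", action)
```

With this threshold, the 6/4 and 5/5 nodes are still split, while 6000/4 and 5000/5 are not, which matches the behaviour asked for above.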

Best answer

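Directly restricting the lowest value of a leaf (the number of occurrences of a particular class) cannot be done with min_impurity_decrease or any other built-in stopping criterion.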

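I think the only way to accomplish this without changing the scikit-learn source code is to post-prune your tree. To do so, you can simply traverse the tree and remove all children of every node whose minimum class count is less than 5 (or any other condition you can think of). I will continue your example: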

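from sklearn.tree._tree import TREE_LEAF

# Note: this assumes tree_.value holds per-class sample counts. Newer
# scikit-learn releases may store class fractions instead, in which case the
# values would need to be scaled by the node's sample count before comparing.
def prune_index(inner_tree, index, threshold):
    if inner_tree.value[index].min() < threshold:
        # turn node into a leaf by "unlinking" its children
        inner_tree.children_left[index] = TREE_LEAF
        inner_tree.children_right[index] = TREE_LEAF
    # if there are children, visit them as well
    if inner_tree.children_left[index] != TREE_LEAF:
        prune_index(inner_tree, inner_tree.children_left[index], threshold)
        prune_index(inner_tree, inner_tree.children_right[index], threshold)

# number of nodes marked as leaves before pruning
print(sum(dt.tree_.children_left < 0))
# start pruning from the root
prune_index(dt.tree_, 0, 5)
# number of nodes marked as leaves after pruning
print(sum(dt.tree_.children_left < 0))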

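This code will first print 74, and then 91. That means the code has created 17 new leaf nodes (in practice, by removing the links from those nodes to their descendants). The tree, which previously looked like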

enter image description here

now looks like

enter image description here

so you can see that it has indeed shrunk a lot.
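To see the unlinking step in isolation, here is a minimal sketch that applies the same recursion to the small 11/9 tree from Update 2, using plain Python lists as a stand-in for the children_left, children_right and value arrays of tree_. The node numbering and the helper reachable_leaves are assumptions made for this illustration. It also suggests why the count printed above goes up: unlinked descendants stay in the arrays and are still counted as leaves even though they can no longer be reached from the root.

```python
TREE_LEAF = -1  # the sentinel scikit-learn uses for "no child"

# Toy tree (assumed numbering):
# 0 = 11/9, 1 = 6/4, 2 = 5/5, 3 = 6/0, 4 = 0/4, 5 = 2/2, 6 = 3/3
children_left = [1, 3, 5, TREE_LEAF, TREE_LEAF, TREE_LEAF, TREE_LEAF]
children_right = [2, 4, 6, TREE_LEAF, TREE_LEAF, TREE_LEAF, TREE_LEAF]
value = [[11, 9], [6, 4], [5, 5], [6, 0], [0, 4], [2, 2], [3, 3]]


def prune_index(index, threshold):
    if min(value[index]) < threshold:
        # turn the node into a leaf by unlinking its children
        children_left[index] = TREE_LEAF
        children_right[index] = TREE_LEAF
    # if the node still has children, visit them as well
    if children_left[index] != TREE_LEAF:
        prune_index(children_left[index], threshold)
        prune_index(children_right[index], threshold)


def reachable_leaves(index=0):
    """Count only the leaves that can still be reached from the root."""
    if children_left[index] == TREE_LEAF:
        return 1
    return (reachable_leaves(children_left[index])
            + reachable_leaves(children_right[index]))


marked_before = children_left.count(TREE_LEAF)
reachable_before = reachable_leaves()

prune_index(0, 5)  # start pruning from the root

marked_after = children_left.count(TREE_LEAF)
reachable_after = reachable_leaves()

print("marked as leaf:  ", marked_before, "->", marked_after)        # 4 -> 5
print("reachable leaves:", reachable_before, "->", reachable_after)  # 4 -> 3
```

Here only the 6/4 node is turned into a leaf, because its smallest class count (4) is below the threshold, while the 5/5 node keeps its children.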

For more on python - Pruning decision trees, see the similar question on Stack Overflow: https://stackoverflow.com/questions/49428469/
