gpt4 book ai didi

python - 如何探索使用 scikit learn 构建的决策树

转载 作者:太空狗 更新时间:2023-10-29 18:16:01 25 4
gpt4 key购买 nike

我正在使用

构建决策树
clf = tree.DecisionTreeClassifier()
clf = clf.fit(X_train, Y_train)

一切正常。但是,我该如何探索决策树?

例如,我如何找到 X_train 中的哪些条目出现在特定的叶子中?

最佳答案

您需要使用预测方法。

在训练树之后,您输入 X 值来预测它们的输出。

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
clf = DecisionTreeClassifier(random_state=0)
iris = load_iris()
tree = clf.fit(iris.data, iris.target)
tree.predict(iris.data)

输出:

>>> tree.predict(iris.data)
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])

要获取树结构的详细信息,我们可以使用tree_.__getstate__()

翻译成“ASCII艺术”图片的树结构

              0  
_____________
1 2
______________
3 12
_______ _______
4 7 13 16
___ ______ _____
5 6 8 9 14 15
_____
10 11

作为数组的树结构。

In [38]: tree.tree_.__getstate__()['nodes']
Out[38]:
array([(1, 2, 3, 0.800000011920929, 0.6666666666666667, 150, 150.0),
(-1, -1, -2, -2.0, 0.0, 50, 50.0),
(3, 12, 3, 1.75, 0.5, 100, 100.0),
(4, 7, 2, 4.949999809265137, 0.16803840877914955, 54, 54.0),
(5, 6, 3, 1.6500000953674316, 0.04079861111111116, 48, 48.0),
(-1, -1, -2, -2.0, 0.0, 47, 47.0),
(-1, -1, -2, -2.0, 0.0, 1, 1.0),
(8, 9, 3, 1.5499999523162842, 0.4444444444444444, 6, 6.0),
(-1, -1, -2, -2.0, 0.0, 3, 3.0),
(10, 11, 2, 5.449999809265137, 0.4444444444444444, 3, 3.0),
(-1, -1, -2, -2.0, 0.0, 2, 2.0),
(-1, -1, -2, -2.0, 0.0, 1, 1.0),
(13, 16, 2, 4.850000381469727, 0.042533081285444196, 46, 46.0),
(14, 15, 1, 3.0999999046325684, 0.4444444444444444, 3, 3.0),
(-1, -1, -2, -2.0, 0.0, 2, 2.0),
(-1, -1, -2, -2.0, 0.0, 1, 1.0),
(-1, -1, -2, -2.0, 0.0, 43, 43.0)],
dtype=[('left_child', '<i8'), ('right_child', '<i8'),
('feature', '<i8'), ('threshold', '<f8'),
('impurity', '<f8'), ('n_node_samples', '<i8'),
('weighted_n_node_samples', '<f8')])

地点:

  • 第一个节点[0]是根节点。
  • 内部节点有 left_child 和 right_child 指的是具有正值且大于当前节点的节点。
  • 叶子的左右子节点的值为-1。
  • 节点 1,5,6, 8,10,11,14,15,16 是叶子。
  • 节点结构是使用深度优先搜索算法构建的。
  • 特征字段告诉我们在节点中使用了哪些 iris.data 特征来确定该样本的路径。
  • 阈值告诉我们用于根据特征评估方向的值。
  • 杂质在叶子处达到 0...因为一旦到达叶子,所有样本都属于同一类。
  • n_node_samples 告诉我们有多少样本到达每片叶子。

使用此信息,我们可以通过遵循脚本上的分类规则和阈值,轻松地将每个样本 X 跟踪到它最终到达的叶子。此外,n_node_samples 将允许我们执行单元测试以确保每个节点获得正确数量的样本。然后使用 tree.predict 的输出,我们可以将每个叶子映射到关联的类。

关于python - 如何探索使用 scikit learn 构建的决策树,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/32506951/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com