gpt4 book ai didi

python - 如何获取 scikit-learn 决策树所有节点的 pos/neg 实例计数?

转载 作者:行者123 更新时间:2023-12-01 04:34:23 25 4
gpt4 key购买 nike

我训练了一个 sklearn 决策树。

from sklearn.tree import DecisionTreeClassifier
c=DecisionTreeClassifier(class_weight="auto")
c.fit([[0,0],
[0,1],
[1,1],
],[0,1,0])

现在我想检查每个节点有多少正/负样本。因此,像这样的图表

  counts: [2,1]            labels: (010)
split by x0
[1,1] [1,0] (01) (0)
split by x1
[1,0] [0,1] 0 (0) (1)
0 1

如何从经过训练的决策树中获取此(剩余计数)?

我可以看到一个c.tree_变量,但内容似乎不是很有帮助。有零、权重……并且很难猜测如何取回计数。

最佳答案

每个类的样本数量存储在tree_.value中,但它只存储叶子的节点值,因此我使用后序遍历来获取所有节点的值。

import numpy as np

def get_value(dt):
left = dt.tree_.children_left
right = dt.tree_.children_right
value = dt.tree_.value
leaves = np.argwhere(left == -1)[:, 0]

def visit(node):
if node in leaves:
return
visit(left[node])
visit(right[node])
value[node, :] = value[left[node], :] + value[right[node], :]

visit(0)
return value

例如,

from sklearn.tree import DecisionTreeClassifier
dt = DecisionTreeClassifier()
dt.fit([[0,0],
[0,1],
[1,1]], [0,1,0])
get_value(dt)

输出:

[[[ 2.  1.]]

[[ 1. 1.]]

[[ 1. 0.]]

[[ 0. 1.]]

[[ 1. 0.]]]

更新#1

我想知道为什么tree_.value只存储叶节点的值,然后我发现https://stackoverflow.com/questions/27417809/show-values-at-each-node-level-of-scikit-learn-decision-treethis issue .

事实证明,在 scikit-learn 0.17.dev0 中,tree_.value 已经返回所有节点的值。

In [1]: from sklearn.tree import DecisionTreeClassifier

In [2]: dt = DecisionTreeClassifier()

In [3]: dt.fit([[0,0],
...: [0,1],
...: [1,1]], [0,1,0])
Out[3]:
DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=None,
max_features=None, max_leaf_nodes=None, min_samples_leaf=1,
min_samples_split=2, min_weight_fraction_leaf=0.0,
random_state=None, splitter='best')

In [4]: dt.tree_.value
Out[4]:
array([[[ 2., 1.]],

[[ 1., 1.]],

[[ 1., 0.]],

[[ 0., 1.]],

[[ 1., 0.]]])

更新#2

尽管我认为在给出 class_weight 时“取消权重”确实没有有意义,但这是可以实现的。

class_weight 的计算方式为

In [1]: from sklearn.utils import compute_class_weight

In [2]: compute_class_weight('auto', [0, 1], [0, 1, 0])
Out[2]: array([ 0.66666667, 1.33333333])

因此可以在if node in leaves:后面添加value[node, :]/= class_weight来重新计算叶子节点的值。

关于python - 如何获取 scikit-learn 决策树所有节点的 pos/neg 实例计数?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/31991530/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com