
python - How to extract the rules of a random forest tree for multi-class classification?

Reposted. Author: 行者123. Updated: 2023-11-30 08:48:10

Hi, I want to extract the rules from a single tree of a random forest in the multi-class classification case.

from sklearn.tree import _tree
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier

# create a random forest classifier
clf = RandomForestClassifier(n_estimators=100)

# train the model using the training set
clf.fit(X_train, y_train)

# extract one tree from the forest
model = clf.estimators_[0]


def find_rules(tree, features):
    dt = tree.tree_

    def visitor(node, depth):
        indent = ' ' * depth
        if dt.feature[node] != _tree.TREE_UNDEFINED:
            print('{} if <{}> <= {}:'.format(indent, features[node], round(dt.threshold[node], 100)))
            visitor(dt.children_left[node], depth + 1)
            print('{}else:'.format(indent))
            visitor(dt.children_right[node], depth + 1)
        else:
            print('{} return {}'.format(indent, dt.value[node]))

    visitor(0, 1)


find_rules(model, iris.feature_names)



Best answer

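Please check the code below; it seems to work. There is only one small change: the feature name is looked up with dt.feature[node] instead of the node id.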

def find_rules(tree, features):
    dt = tree.tree_

    def visitor(node, depth):
        indent = ' ' * depth
        if dt.feature[node] != _tree.TREE_UNDEFINED:
            # changed line: dt.feature[node] maps the node id back to the
            # index of the feature it splits on
            print('{} if <{}> <= {}:'.format(indent, features[dt.feature[node]], round(dt.threshold[node], 100)))
            visitor(dt.children_left[node], depth + 1)
            print('{} else:'.format(indent))
            visitor(dt.children_right[node], depth + 1)
        else:
            print('{} return {}'.format(indent, dt.value[node]))

    visitor(0, 1)

Regarding "python - How to extract the rules of a random forest tree for multi-class classification?", a similar question can be found on Stack Overflow: https://stackoverflow.com/questions/56791341/
