gpt4 book ai didi

python-3.x - 使用 class_names 使用 graphviz 的树节点的颜色

转载 作者:行者123 更新时间:2023-12-03 16:05:36 25 4
gpt4 key购买 nike

扩展先前的问题:
Changing colors for decision tree plot created using export graphviz

我将如何根据主导类(鸢尾花的种类)而不是二元区分为树的节点着色?这应该需要 iris.target_names(描述​​类的字符串)和 iris.target(类)的组合。

import pydotplus
from sklearn.datasets import load_iris
from sklearn import tree
import collections

clf = tree.DecisionTreeClassifier(random_state=42)
iris = load_iris()

clf = clf.fit(iris.data, iris.target)

dot_data = tree.export_graphviz(clf, out_file=None,
feature_names=iris.feature_names,
class_names=iris.target_names,
filled=True, rounded=True,
special_characters=True)
graph = pydotplus.graph_from_dot_data(dot_data)
nodes = graph.get_node_list()
edges = graph.get_edge_list()

colors = ('brown', 'forestgreen')
edges = collections.defaultdict(list)

for edge in graph.get_edge_list():
edges[edge.get_source()].append(int(edge.get_destination()))

for edge in edges:
edges[edge].sort()
for i in range(2):
dest = graph.get_node(str(edges[edge][i]))[0]
dest.set_fillcolor(colors[i])

graph.write_png('tree.png')

最佳答案

示例中的代码看起来很熟悉,因此很容易修改:)

对于每个节点 Graphviz告诉我们每个组有多少样本,即它是混合种群还是树做出决定。我们可以提取此信息并用于获取颜色。

values = [int(ii) for ii in node.get_label().split('value = [')[1].split(']')[0].split(',')]

或者,您可以映射 GraphViz节点回 sklearn节点:
values = clf.tree_.value[int(node.get_name())][0]

我们只有 3 个类,所以每个类都有自己的颜色(红色、绿色、蓝色),混合种群根据它们的分布得到混合颜色。
values = [int(255 * v / sum(values)) for v in values]
color = '#{:02x}{:02x}{:02x}'.format(values[0], values[1], values[2])

enter image description here

我们现在可以很好地看到分离,绿色越多,我们拥有的第二类就越多,蓝色和第三类也是如此。
import pydotplus
from sklearn.datasets import load_iris
from sklearn import tree

clf = tree.DecisionTreeClassifier(random_state=42)
iris = load_iris()

clf = clf.fit(iris.data, iris.target)

dot_data = tree.export_graphviz(clf,
feature_names=iris.feature_names,
out_file=None,
filled=True,
rounded=True,
special_characters=True)
graph = pydotplus.graph_from_dot_data(dot_data)
nodes = graph.get_node_list()

for node in nodes:
if node.get_label():
values = [int(ii) for ii in node.get_label().split('value = [')[1].split(']')[0].split(',')]
values = [int(255 * v / sum(values)) for v in values]
color = '#{:02x}{:02x}{:02x}'.format(values[0], values[1], values[2])
node.set_fillcolor(color)

graph.write_png('colored_tree.png')

超过 3 个类的通用解决方案,只为最终节点着色。
colors =  ('lightblue', 'lightyellow', 'forestgreen', 'lightred', 'white')

for node in nodes:
if node.get_name() not in ('node', 'edge'):
values = clf.tree_.value[int(node.get_name())][0]
#color only nodes where only one class is present
if max(values) == sum(values):
node.set_fillcolor(colors[numpy.argmax(values)])
#mixed nodes get the default color
else:
node.set_fillcolor(colors[-1])

enter image description here

关于python-3.x - 使用 class_names 使用 graphviz 的树节点的颜色,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/43214350/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com