gpt4 book ai didi

python - 如何在 DecisionTreeClassifier 中设置类权重以进行多类设置

转载 作者:行者123 更新时间:2023-12-04 15:59:04 25 4
gpt4 key购买 nike

我正在使用 sklearn.tree.DecisionTreeClassifier 来训练 3-class 分类问题。

3个类的记录数如下:

A: 122038
B: 43626
C: 6678

当我训练分类器模型时,它无法学习类 - C。虽然效率是 65-70%,但它完全忽略了 C 类。

后来我知道了 class_weight 参数,但我不知道如何在多类设置中使用它。

这是我的代码:(我使用了 balanced 但它给出的准确度更差)

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=1)
clf = tree.DecisionTreeClassifier(criterion="gini", max_depth=3, random_state=1,class_weight='balanced')
clf = clf.fit(X_train,y_train)
y_pred = clf.predict(X_test)

如何使用与类分布成比例的权重。

其次,有没有更好的方法来解决这个不平衡类问题以提高准确性?

最佳答案

您还可以将值字典传递给 class_weight 参数以设置您自己的权重。例如,将 A 级的重量减半:

class_weight={
'A': 0.5,
'B': 1.0,
'C': 1.0
}

通过执行 class_weight='balanced',它会自动设置与类频率成反比的权重。

更多信息可以在 class_weight 参数下的文档中找到: https://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeClassifier.html

通常可以预期平衡类会降低准确性。这就是为什么准确性通常被认为是不平衡数据集的一个较差指标的原因。

您可以尝试 sklearn 包含的 Balanced Accuracy 指标开始,但还有许多其他潜在指标可供尝试,这取决于您的最终目标是什么。

https://scikit-learn.org/stable/modules/model_evaluation.html

如果您不熟悉“混淆矩阵”及其相关值(例如精度和召回率),那么我会从那里开始您的研究。

https://en.wikipedia.org/wiki/Precision_and_recall

https://en.wikipedia.org/wiki/Confusion_matrix

https://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html

关于python - 如何在 DecisionTreeClassifier 中设置类权重以进行多类设置,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/62581004/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com