
python - Scikit Learn-MultinomialNB for text classification

Reposted. Author: 行者123. Updated: 2023-11-30 09:08:33

How do I compute FPR, TPR, AUC, and roc_curve for multiclass text classification?

I used the following code:

from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
from sklearn.feature_extraction.text import CountVectorizer
vect=CountVectorizer()
vect.fit(X_train.values.astype('U'))
X_train_dtm=vect.transform(X_train.values.astype('U'))
X_test_dtm=vect.transform(X_test)
from sklearn.naive_bayes import MultinomialNB
nb = MultinomialNB()
y_score=nb.fit(X_train_dtm, y_train)
y_pred_class = nb.predict(X_test_dtm)

Everything works fine up to this point. But as soon as I run the following code, I get an error.

from sklearn.metrics import roc_curve, auc, roc_auc_score
fpr = dict()
tpr = dict()
roc_auc = dict()
for i in range(5):
    fpr[i], tpr[i], _ = roc_curve(y_test[:, i], y_score[:, i])
    roc_auc[i] = auc(fpr[i], tpr[i])
fpr["micro"], tpr["micro"], _ = roc_curve(y_test.ravel(), y_score.ravel())
roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])
print ("ROC value is:",roc_auc["micro"])

The error is:

Traceback (most recent call last):    

File "C:/Users/saurabh/PycharmProjects/getting_started/own_code.py", line 32, in <module>
print(metrics.roc_auc_score(y_test, y_pred_prob))

File "C:\Anaconda3\lib\site-packages\sklearn\metrics\ranking.py", line 260, in roc_auc_score
sample_weight=sample_weight) Accuracy by this: 0.910536779324

File "C:\Anaconda3\lib\site-packages\sklearn\metrics\base.py", line 81, in _average_binary_score
raise ValueError("{0} format is not supported".format(y_type))

ValueError: multiclass format is not supported

Best answer

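roc_curve does not support the multiclass format. You have to reduce the problem to binary classes (one class versus the rest) and compute the curve for each of them.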
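If only the AUC value is needed rather than the full curve, one option is roc_auc_score with multi_class="ovr", which handles the one-vs-rest reduction internally. This is a minimal sketch, assuming scikit-learn 0.22 or later (where that parameter is available). The labels and probabilities below are made-up stand-ins for y_test and for the output of nb.predict_proba(X_test_dtm); each probability row must sum to 1.

```python
import numpy as np
from sklearn.metrics import roc_auc_score

# Hypothetical data: six samples, three classes.
# y_proba stands in for nb.predict_proba(X_test_dtm); each row sums to 1.
y_true = np.array([0, 1, 2, 0, 1, 2])
y_proba = np.array([
    [0.7, 0.2, 0.1],
    [0.2, 0.6, 0.2],
    [0.1, 0.3, 0.6],
    [0.5, 0.3, 0.2],
    [0.3, 0.5, 0.2],
    [0.2, 0.2, 0.6],
])

# One-vs-rest AUC, averaged over the classes with equal weight.
macro_auc = roc_auc_score(y_true, y_proba, multi_class="ovr", average="macro")
print("Macro-averaged one-vs-rest AUC:", macro_auc)
```

The columns of y_proba are expected to follow the sorted order of the class labels, which is also the order of nb.classes_ for a fitted MultinomialNB.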

To compute FPR and TPR, however, you can use confusion_matrix:

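import numpy as np
from sklearn.metrics import confusion_matrix

# Assumes y_test is one-hot encoded and y_score holds one score column per class
y_test = np.argmax(y_test, axis=1)
y_score = np.argmax(y_score, axis=1)
c = confusion_matrix(y_test, y_score)  # rows: true class, columns: predicted class
# Cells are counts; these indices cover classes 0 and 1 only (binary case)
TN = float(c[0][0])
TP = float(c[1][1])
FN = float(c[1][0])
FP = float(c[0][1])
TPR = TP / (TP + FN)
FPR = FP / (FP + TN)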
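The snippet above reads four cells of the matrix, which covers the two-class case. For more than two classes, one way to get a TPR and FPR per class is to treat each class as "positive" against all the others and take row and column sums of the confusion matrix. The following is a sketch of that idea; the label lists and the names tpr_per_class and fpr_per_class are hypothetical and not part of the original answer.

```python
import numpy as np
from sklearn.metrics import confusion_matrix

# Hypothetical true and predicted class labels for a three-class problem.
y_true = [0, 0, 1, 1, 2, 2, 2, 0]
y_pred = [0, 1, 1, 1, 2, 0, 2, 0]

# Rows are true classes, columns are predicted classes.
c = confusion_matrix(y_true, y_pred)

# One-vs-rest counts for every class at once.
tp = np.diag(c).astype(float)   # predicted as the class and truly the class
fn = c.sum(axis=1) - tp         # truly the class, predicted as another one
fp = c.sum(axis=0) - tp         # predicted as the class, truly another one
tn = c.sum() - (tp + fn + fp)   # everything else

tpr_per_class = tp / (tp + fn)
fpr_per_class = fp / (fp + tn)
print("TPR per class:", tpr_per_class)
print("FPR per class:", fpr_per_class)
```

These are rates at a single decision threshold (the predicted label), so they give one point per class rather than a full ROC curve.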

Here is a simple binarization example:

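# Assumes y_test is a 1-D sequence of class labels 0..4 and
# y_score is an (n_samples, 5) array of per-class scores, e.g. from nb.predict_proba
for i in range(5):
    yt_bin = [1 if x == i else 0 for x in y_test]
    fpr[i], tpr[i], _ = roc_curve(yt_bin, y_score[:, i])
    roc_auc[i] = auc(fpr[i], tpr[i])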
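The binarization loop can be expanded into a self-contained sketch that runs end to end. Two things here are assumptions rather than part of the original answer: the toy corpus (a made-up stand-in for the question's X and y, with three classes instead of five) and the use of label_binarize to build the indicator columns. Note also that in the question y_score is assigned the return value of nb.fit, which is the fitted estimator itself; the sketch takes the scores from nb.predict_proba instead.

```python
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.metrics import auc, roc_curve
from sklearn.naive_bayes import MultinomialNB
from sklearn.preprocessing import label_binarize

# Toy corpus (hypothetical data, three classes) standing in for X and y.
train_docs = ["goal match team", "team win goal", "vote election party",
              "party vote law", "stock market price", "price market trade"]
train_labels = [0, 0, 1, 1, 2, 2]
test_docs = ["team goal", "election vote", "market stock",
             "win match", "law party", "trade price"]
test_labels = [0, 1, 2, 0, 1, 2]

vect = CountVectorizer()
X_train_dtm = vect.fit_transform(train_docs)
X_test_dtm = vect.transform(test_docs)

nb = MultinomialNB()
nb.fit(X_train_dtm, train_labels)
y_score = nb.predict_proba(X_test_dtm)  # shape (n_samples, n_classes)

# One 0/1 indicator column per class, in the same order as nb.classes_.
y_test_bin = label_binarize(test_labels, classes=nb.classes_)

fpr, tpr, roc_auc = {}, {}, {}
for i in range(len(nb.classes_)):
    # Class i against all other classes.
    fpr[i], tpr[i], _ = roc_curve(y_test_bin[:, i], y_score[:, i])
    roc_auc[i] = auc(fpr[i], tpr[i])

# Micro-average: pool every (sample, class) decision into one binary problem.
fpr["micro"], tpr["micro"], _ = roc_curve(y_test_bin.ravel(), y_score.ravel())
roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])
print("Micro-average AUC:", roc_auc["micro"])
```

With the question's data, the same pattern would apply with X_test_dtm, y_test, and five classes in place of the toy values.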

Regarding python - Scikit Learn-MultinomialNB for text classification, a similar question was found on Stack Overflow: https://stackoverflow.com/questions/46014642/
