python - Scikit-learn: getting the probability that a sample belongs to a class

Reposted. Author: 行者123. Updated: 2023-11-30 09:21:32

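Hello, I am trying to classify text into 4 categories, and I want to print the prediction together with the probability that the text belongs to each category.
After reading the Scikit-learn documentation, I believe I should use predict_proba. This is my code so far: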

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import sys
from sklearn.datasets import load_files
from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
from sklearn.naive_bayes import MultinomialNB

string = sys.argv[1]  # the text to classify is passed on the command line
sets = load_files('scikit')  # load the training set

# Character n-gram counts for the training documents
count_vect = CountVectorizer(analyzer='char_wb', ngram_range=(0, 3), min_df=1)
X_train_counts = count_vect.fit_transform(sets.data)

# Term-frequency features (not used further below)
tf_transformer = TfidfTransformer(use_idf=False).fit(X_train_counts)
X_train_tf = tf_transformer.transform(X_train_counts)

# TF-IDF features used to train the classifier
tfidf_transformer = TfidfTransformer()
X_train_tfidf = tfidf_transformer.fit_transform(X_train_counts)

clf = MultinomialNB().fit(X_train_tfidf, sets.target)

# Transform the new text with the already fitted vectorizer and transformer
docs_new = [string]
X_new_counts = count_vect.transform(docs_new)
X_new_tfidf = tfidf_transformer.transform(X_new_counts)
predicted = clf.predict(X_new_tfidf)
for doc, category in zip(docs_new, predicted):
    print('%r => %s' % (doc, sets.target_names[category]))  # prints the prediction, and it is correct
    print(clf.predict_proba(sets.target_names))  # trying to get the probabilities for all classes

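Unfortunately, the output is: ValueError: objects are not aligned. I have tried many different approaches and searched the web extensively, but nothing seems to work. Any suggestion would be much appreciated. Thanks, Nico.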

Best answer

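The input to predict_proba() should be exactly the same as the input you passed to predict(). sets.target_names is a list of class names, not a feature matrix produced by the vectorizer, so the classifier cannot use it. You can get the probabilities with: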

clf.predict_proba(X_new_tfidf)
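
To make the result easier to read, each probability can be paired with its class name. The following is only a minimal sketch: the toy corpus, the two class names, and the sample sentence are hypothetical stand-ins for the data that load_files('scikit') would provide, and ngram_range=(1, 3) is an assumption chosen for the illustration. The columns returned by predict_proba follow the order of clf.classes_.

```python
import numpy as np
from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
from sklearn.naive_bayes import MultinomialNB

# Hypothetical toy corpus standing in for load_files('scikit').
train_docs = [
    "the match ended with a late goal",
    "the striker scored twice in the final",
    "the new phone has a faster processor",
    "the laptop ships with more memory",
]
train_targets = np.array([0, 0, 1, 1])
target_names = ["sport", "tech"]  # hypothetical class names

# Fit the vectorizer, the TF-IDF transformer and the classifier on the training data.
count_vect = CountVectorizer(analyzer="char_wb", ngram_range=(1, 3))
tfidf_transformer = TfidfTransformer()
X_train_tfidf = tfidf_transformer.fit_transform(count_vect.fit_transform(train_docs))
clf = MultinomialNB().fit(X_train_tfidf, train_targets)

# New text goes through the same fitted transformers before prediction.
docs_new = ["the goalkeeper saved a penalty"]
X_new_tfidf = tfidf_transformer.transform(count_vect.transform(docs_new))

predicted = clf.predict(X_new_tfidf)
probabilities = clf.predict_proba(X_new_tfidf)  # shape: (n_docs, n_classes)

for doc, category, row in zip(docs_new, predicted, probabilities):
    print("%r => %s" % (doc, target_names[category]))
    # Each column of the row corresponds to the class at the same index in clf.classes_.
    for class_index, probability in zip(clf.classes_, row):
        print("  %s: %.3f" % (target_names[class_index], probability))
```

Each row of the returned array sums to 1, and the class with the largest probability is the one predict() reports.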

Regarding python - Scikit-learn: getting the probability that a sample belongs to a class, a similar question was found on Stack Overflow: https://stackoverflow.com/questions/35712828/
