gpt4 book ai didi

python - 如何使用 sklearn 的 SGDClassifier 返回前 N 个预测的准确率?

转载 作者:太空宇宙 更新时间:2023-11-04 02:05:27 26 4
gpt4 key购买 nike

我正在尝试修改这篇文章(How to get Top 3 or Top N predictions using sklearn's SGDClassifier)中的结果以获得返回的准确率,但是我得到的准确率为零而且我无法弄清楚为什么。有什么想法吗?任何想法/编辑将不胜感激!谢谢。

from sklearn.feature_extraction.text import TfidfVectorizer
import numpy as np
from sklearn import linear_model
arr=['dogs cats lions','apple pineapple orange','water fire earth air', 'sodium potassium calcium']
vectorizer = TfidfVectorizer()
X = vectorizer.fit_transform(arr)
feature_names = vectorizer.get_feature_names()
Y = ['animals', 'fruits', 'elements','chemicals']
T=["eating apple roasted in fire and enjoying fresh air"]
test = vectorizer.transform(T)
clf = linear_model.SGDClassifier(loss='log')
clf.fit(X,Y)
x=clf.predict(test)

def top_n_accuracy(probs, test, n):
best_n = np.argsort(probs, axis=1)[:,-n:]
ts = np.argmax(test, axis=1)
successes = 0
for i in range(ts.shape[0]):
if ts[i] in best_n[i,:]:
successes += 1
return float(successes)/ts.shape[0]

n=2
probs = clf.predict_proba(test)

top_n_accuracy(probs, test, n)

最佳答案

from sklearn.feature_extraction.text import TfidfVectorizer
import numpy as np
from sklearn import linear_model

arr=['dogs cats lions','apple pineapple orange','water fire earth air', 'sodium potassium calcium']
vectorizer = TfidfVectorizer()
X = vectorizer.fit_transform(arr)
feature_names = vectorizer.get_feature_names()
Y = ['animals', 'fruits', 'elements','chemicals']
T=["eating apple roasted in fire and enjoying fresh air", "I love orange"]
test = vectorizer.transform(T)

clf = linear_model.SGDClassifier(loss='log')
clf.fit(X,Y)
x=clf.predict(test)

n=2
probs = clf.predict_proba(test)

topn = np.argsort(probs, axis = 1)[:,-n:]

这里我介绍了地面真值标签向量(这些是数字索引,你需要将 [“元素”等] 映射到 [0,1,2 等]。这里我假设你的测试示例属于元素。

y_true = np.array([2,1])

然后这应该计算你的准确度

np.mean(np.array([1 if y_true[k] in topn[k] else 0 for k in range(len(topn))]))

关于python - 如何使用 sklearn 的 SGDClassifier 返回前 N 个预测的准确率?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/54855656/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com