gpt4 book ai didi

python - "Invalid Index to Scalar Variable"- 使用 Scikit Learn 时 "accuracy_score"

转载 作者:行者123 更新时间:2023-11-30 09:38:21 29 4
gpt4 key购买 nike

不确定到底出了什么问题。然而,我的目标是建立一个交叉验证的Python代码。我知道有多种衡量标准,但我认为我使用的是正确的衡量标准。我没有得到我想要的 CV10 结果,而是收到错误:

“标量变量索引无效”

我在 StackOverflow 上发现了这个:索引错误:当您尝试对 numpy 标量(例如 numpy.int64 或 numpy.float64)建立索引时,会发生标量变量的无效索引。它与 TypeError: 'int' object has no attribute '_getitem_' 当您尝试索引 int 时非常相似。

如有任何帮助,我们将不胜感激...

我正在尝试关注::http://scikit-learn.org/stable/modules/model_evaluation.html

from sklearn.ensemble import RandomForestClassifier
from sklearn import cross_validation
from numpy import genfromtxt
import numpy as np
from sklearn.metrics import accuracy_score

def main():
#read in data, parse into training and target sets
dataset = genfromtxt(open('D:\\CA_DataPrediction_TrainData\\CA_DataPrediction_TrainDataGenetic.csv','r'), delimiter=',', dtype='f8')[1:]
target = np.array( [x[0] for x in dataset] )
train = np.array( [x[1:] for x in dataset] )

#In this case we'll use a random forest, but this could be any classifier
cfr = RandomForestClassifier(n_estimators=10)

#Simple K-Fold cross validation. 10 folds.
cv = cross_validation.KFold(len(train), k=10, indices=False)

#iterate through the training and test cross validation segments and
#run the classifier on each one, aggregating the results into a list
results = []
for traincv, testcv in cv:
pred = cfr.fit(train[traincv], target[traincv]).predict(train[testcv])
results.append(accuracy_score(target[testcv], [x[1] for x in pred]) )

#print out the mean of the cross-validated results
print "Results: " + str( np.array(results).mean() )

if __name__=="__main__":
main()

最佳答案

您的 pred 变量只是一个预测列表,因此您无法为其元素建立索引(这就是错误的原因)

results.append(accuracy_score(target[testcv], [x[1] for x in pred]) )

应该是

results.append(accuracy_score(target[testcv], pred) )

或者如果你真的想要一份副本

results.append(accuracy_score(target[testcv], [x for x in pred]) )

关于python - "Invalid Index to Scalar Variable"- 使用 Scikit Learn 时 "accuracy_score",我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/19212845/

29 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com