python - StratifiedKFold: IndexError: too many indices for array

Reposted. Author: 太空狗. Updated: 2023-10-29 20:47:14

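I am using sklearn's StratifiedKFold function. Can someone help me understand the error here?

My guess is that it has something to do with the label array I am passing in. When I print the labels (the first 16 in this example), the index runs from 0 to 15, but there is an extra 0 printed above them that I did not expect. Maybe I am just a Python newbie, but that looks strange.

Does anyone see the mistake here?

Documentation: http://scikit-learn.org...StratifiedKFold.html

Code: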

import nltk
import sklearn

print('The nltk version is {}.'.format(nltk.__version__))
print('The scikit-learn version is {}.'.format(sklearn.__version__))

print type(skew_gendata_targets.values), skew_gendata_targets.values.shape
print skew_gendata_targets.head(16)

skew_sfold10 = cross_validation.StratifiedKFold(skew_gendata_targets.values, n_folds=10, shuffle=True, random_state=20160121)

Result:

The nltk version is 3.1.
The scikit-learn version is 0.17.
<type 'numpy.ndarray'> (500L, 1L)
    0
0   0
1   0
2   0
3   0
4   0
5   0
6   0
7   0
8   0
9   0
10  0
11  0
12  0
13  0
14  1
15  0
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
<ipython-input-373-653b6010b806> in <module>()
8 print skew_gendata_targets.head(16)
9
---> 10 skew_sfold10 = cross_validation.StratifiedKFold(skew_gendata_targets.values, n_folds=10, shuffle=True, random_state=20160121)
11
12 #print '\nSkewed Generated Dataset (', len(skew_gendata_data), ')'

d:\Program Files\Anaconda2\lib\site-packages\sklearn\cross_validation.pyc in __init__(self, y, n_folds, shuffle, random_state)
531 for test_fold_idx, per_label_splits in enumerate(zip(*per_label_cvs)):
532 for label, (_, test_split) in zip(unique_labels, per_label_splits):
--> 533 label_test_folds = test_folds[y == label]
534 # the test split can be too big because we used
535 # KFold(max(c, self.n_folds), self.n_folds) instead of

IndexError: too many indices for array

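Best answer

Check the shape of skew_gendata_targets.values. You will see that it is not the 1-D array that StratifiedKFold expects (shape (500,)), but a (500, 1) array. sklearn treats the two shapes differently rather than coercing one into the other. Let me know if that helps.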
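The shape mismatch described above can be reproduced and fixed in a few lines. This is only a sketch under stated assumptions: `targets` is a hypothetical stand-in for `skew_gendata_targets` (assumed here to be a one-column DataFrame, which would also explain the extra 0 in the printout, since that is the column label), and it uses `sklearn.model_selection.StratifiedKFold` with `n_splits`, because the `sklearn.cross_validation` module from the question was removed in later scikit-learn releases. On 0.17, the equivalent change would be to pass `skew_gendata_targets.values.ravel()` to the original call.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedKFold

# Hypothetical stand-in for skew_gendata_targets: a one-column DataFrame
# holding 500 binary labels (450 zeros, 50 ones).
targets = pd.DataFrame(np.array([0] * 450 + [1] * 50))

y_2d = targets.values          # shape (500, 1), as passed in the question
y_1d = targets.values.ravel()  # shape (500,), the 1-D form the answer asks for
print(y_2d.shape, y_1d.shape)

# The failing line in the traceback indexes a 1-D array with a mask built
# from y. A mask derived from a 2-D y has one dimension too many.
test_folds = np.zeros(len(y_1d), dtype=int)
try:
    test_folds[y_2d == 0]
except IndexError as exc:
    print("2-D mask fails:", exc)
print("1-D mask selects", test_folds[y_1d == 0].shape[0], "entries")

# With 1-D labels the stratified split runs.
skf = StratifiedKFold(n_splits=10, shuffle=True, random_state=20160121)
X_placeholder = np.zeros((len(y_1d), 1))  # dummy features, only y drives stratification
for train_idx, test_idx in skf.split(X_placeholder, y_1d):
    print(len(test_idx), int(y_1d[test_idx].sum()))
```

`ravel()` is one of several ways to get a 1-D array. Selecting the single column as a Series before taking `.values` (for example `targets[0].values`) should give the same shape.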

Regarding python - StratifiedKFold: IndexError: too many indices for array, a similar question can be found on Stack Overflow: https://stackoverflow.com/questions/35022463/
