作者热门文章
- html - 出于某种原因,IE8 对我的 Sass 文件中继承的 html5 CSS 不友好?
- JMeter 在响应断言中使用 span 标签的问题
- html - 在 :hover and :active? 上具有不同效果的 CSS 动画
- html - 相对于居中的 html 内容固定的 CSS 重复背景?
我正在研究Assignment 3: Regularization 。看过之后Github ,我尝试自己解决作业,但出现运行时错误。请注意,我选择了比链接更小的数据集。
情况是这样的:
print('Training set', train_dataset.shape, train_labels.shape)
print('Validation set', valid_dataset.shape, valid_labels.shape)
print('Test set', test_dataset.shape, test_labels.shape)
#Training set (20000, 784) (20000, 10)
#Validation set (1000, 784) (1000, 10)
#Test set (1000, 784) (1000, 10)
问题是这样的:
from sklearn.linear_model import LogisticRegression
original_train_labels = train_labels
logit_clf = LogisticRegression(penalty='l2')
logit_clf.fit(train_dataset[:1000,:], original_train_labels[:1000])
运行时,给出:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-12-4888dc0bbc75> in <module>()
4
5 logit_clf = LogisticRegression(penalty='l2')
----> 6 logit_clf.fit(train_dataset[:1000,:], original_train_labels[:1000])
7 predicted = logit_clf.predict(test_dataset)
8 print('accuracy', accuracy((np.arange(num_labels) == predicted[:,None]).astype(np.float32), test_labels), '%')
/usr/local/lib/python2.7/dist-packages/sklearn/linear_model/logistic.pyc in fit(self, X, y, sample_weight)
1140
1141 X, y = check_X_y(X, y, accept_sparse='csr', dtype=np.float64,
-> 1142 order="C")
1143 check_classification_targets(y)
1144 self.classes_ = np.unique(y)
/usr/local/lib/python2.7/dist-packages/sklearn/utils/validation.pyc in check_X_y(X, y, accept_sparse, dtype, order, copy, force_all_finite, ensure_2d, allow_nd, multi_output, ensure_min_samples, ensure_min_features, y_numeric, warn_on_dtype, estimator)
513 dtype=None)
514 else:
--> 515 y = column_or_1d(y, warn=True)
516 _assert_all_finite(y)
517 if y_numeric and y.dtype.kind == 'O':
/usr/local/lib/python2.7/dist-packages/sklearn/utils/validation.pyc in column_or_1d(y, warn)
549 return np.ravel(y)
550
--> 551 raise ValueError("bad input shape {0}".format(shape))
552
553
ValueError: bad input shape (1000, 10)
知道如何解决这个问题吗?
最佳答案
您对 train_labels 使用 one-hot 编码。这意味着它的形状类似于 [1000. 10],1000 个样本,每个样本有 10 个“列”,其中 1 表示我们正在讨论哪个类。它是神经网络所必需的,但来自 sklearn 的逻辑回归 requires它的形状为 [1000, 1],这意味着它应该只是一个 1000 行的向量,并且在每一行中应该有一个表示目标类的 int 。使用 argmax 函数将 one-hot 编码转换为整数,您应该已准备就绪。
关于python - 优达学城 : Assignment 3: ValueError: bad input shape (1000, 10),我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/38297894/
我是一名优秀的程序员,十分优秀!