gpt4 book ai didi

python - SGDClassifier 在 MNIST 上的使用

转载 作者:行者123 更新时间:2023-11-30 09:35:06 25 4
gpt4 key购买 nike

我正在使用使用 scikit_Learn 进行机器学习实践”(O'Reilly)进行自学,目前正在使用具有不同分类器的 MNIST 数据。

第 94 页的文本表示,SGDClassifier 能够执行多类分类并使用 OvA 算法。当我尝试像这样拟合分类器时:

sgd_clf = SGDClassifier()
sgd_clf.fit(x_train, y_train)

我收到错误:

bad input shape (55000, 10).

这似乎与文本相矛盾。

<小时/>

数据信息

x_train.shape 为 55000x784,y_train.shape 为 55000x10,它们都是 numpy.ndarray

当我安装KNeighborsClassifier时,它工作得很好。

SGDClassifier 能否解决多类分类问题?

谢谢!

最佳答案

the documentation 中所述,

As other classifiers, SGD has to be fitted with two arrays: an array X of size [n_samples, n_features] holding the training samples, and an array Y of size [n_samples] holding the target values (class labels) for the training samples

这意味着 y 是一个由类标签组成的一维数组,如下例所示(取自上面的链接):

>>> from sklearn.linear_model import SGDClassifier
>>> X = [[0., 0.], [1., 1.]]
>>> y = [0, 1]
>>> clf = SGDClassifier(loss="hinge", penalty="l2")
>>> clf.fit(X, y)
SGDClassifier(alpha=0.0001, average=False, class_weight=None, epsilon=0.1,
eta0=0.0, fit_intercept=True, l1_ratio=0.15,
learning_rate='optimal', loss='hinge', n_iter=5, n_jobs=1,
penalty='l2', power_t=0.5, random_state=None, shuffle=True,
verbose=0, warm_start=False)

因此,您应该将 y 转换为由类标签(在您的情况下为 0-9)组成的向量。

关于python - SGDClassifier 在 MNIST 上的使用,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/45284375/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com