gpt4 book ai didi

machine-learning - 值错误 : 'mean_squared_error' is not a valid scoring value

转载 作者:行者123 更新时间:2023-12-05 08:49:41 25 4
gpt4 key购买 nike

因此,我一直致力于我的第一个 ML 项目,作为该项目的一部分,我一直在尝试 sci-kit learn 中的各种模型,并为随机森林模型编写了这段代码:

#Random Forest
reg = RandomForestRegressor(random_state=0, criterion = 'mse')
#Apply grid search for best parameters
params = {'randomforestregressor__n_estimators' : range(100, 500, 200),
'randomforestregressor__min_samples_split' : range(2, 10, 3)}
pipe = make_pipeline(reg)
grid = GridSearchCV(pipe, param_grid = params, scoring='mean_squared_error', n_jobs=-1, iid=False, cv=5)
reg = grid.fit(X_train, y_train)
print('Best MSE: ', grid.best_score_)
print('Best Parameters: ', grid.best_estimator_)

y_train_pred = reg.predict(X_train)
y_test_pred = reg.predict(X_test)
tr_err = mean_squared_error(y_train_pred, y_train)
ts_err = mean_squared_error(y_test_pred, y_test)
print(tr_err, ts_err)
results_train['random_forest'] = tr_err
results_test['random_forest'] = ts_err

但是,当我运行这段代码时,出现以下错误:

KeyError                                  Traceback (most recent call last)
~\anaconda3\lib\site-packages\sklearn\metrics\_scorer.py in get_scorer(scoring)
359 else:
--> 360 scorer = SCORERS[scoring]
361 except KeyError:

KeyError: 'mean_squared_error'

During handling of the above exception, another exception occurred:

ValueError Traceback (most recent call last)
<ipython-input-149-394cd9e0c273> in <module>
5 pipe = make_pipeline(reg)
6 grid = GridSearchCV(pipe, param_grid = params, scoring='mean_squared_error', n_jobs=-1, iid=False, cv=5)
----> 7 reg = grid.fit(X_train, y_train)
8 print('Best MSE: ', grid.best_score_)
9 print('Best Parameters: ', grid.best_estimator_)

~\anaconda3\lib\site-packages\sklearn\utils\validation.py in inner_f(*args, **kwargs)
71 FutureWarning)
72 kwargs.update({k: arg for k, arg in zip(sig.parameters, args)})
---> 73 return f(**kwargs)
74 return inner_f
75

~\anaconda3\lib\site-packages\sklearn\model_selection\_search.py in fit(self, X, y, groups, **fit_params)
652 cv = check_cv(self.cv, y, classifier=is_classifier(estimator))
653
--> 654 scorers, self.multimetric_ = _check_multimetric_scoring(
655 self.estimator, scoring=self.scoring)
656

~\anaconda3\lib\site-packages\sklearn\metrics\_scorer.py in _check_multimetric_scoring(estimator, scoring)
473 if callable(scoring) or scoring is None or isinstance(scoring,
474 str):
--> 475 scorers = {"score": check_scoring(estimator, scoring=scoring)}
476 return scorers, False
477 else:

~\anaconda3\lib\site-packages\sklearn\utils\validation.py in inner_f(*args, **kwargs)
71 FutureWarning)
72 kwargs.update({k: arg for k, arg in zip(sig.parameters, args)})
---> 73 return f(**kwargs)
74 return inner_f
75

~\anaconda3\lib\site-packages\sklearn\metrics\_scorer.py in check_scoring(estimator, scoring, allow_none)
403 "'fit' method, %r was passed" % estimator)
404 if isinstance(scoring, str):
--> 405 return get_scorer(scoring)
406 elif callable(scoring):
407 # Heuristic to ensure user has not passed a metric

~\anaconda3\lib\site-packages\sklearn\metrics\_scorer.py in get_scorer(scoring)
360 scorer = SCORERS[scoring]
361 except KeyError:
--> 362 raise ValueError('%r is not a valid scoring value. '
363 'Use sorted(sklearn.metrics.SCORERS.keys()) '
364 'to get valid options.' % scoring)

ValueError: 'mean_squared_error' is not a valid scoring value. Use sorted(sklearn.metrics.SCORERS.keys()) to get valid options.

因此,我尝试通过从 GridSearchCV(pipe, param_grid = params, scoring='mean_squared_error', n_jobs=-1, iid=False , cv=5)。当我这样做时,代码运行完美,并给出了足够好的训练和测试错误。

无论如何,我无法弄清楚为什么 GridSearchCV 函数中的 scoring='mean_squared_error' 参数会抛出该错误。我做错了什么?

最佳答案

根据documentation :

All scorer objects follow the convention that higher return values are better than lower return values. Thus metrics which measure the distance between the model and the data, like metrics.mean_squared_error, are available as neg_mean_squared_error which return the negated value of the metric.

这意味着您必须传递 scoring='neg_mean_squared_error' 才能使用均方误差评估网格搜索结果。

关于machine-learning - 值错误 : 'mean_squared_error' is not a valid scoring value,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/63473565/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com