gpt4 book ai didi

python - 如何在(GridSearchCV)拟合模型后打印估计系数? (SGDRegressor)

转载 作者:太空狗 更新时间:2023-10-29 21:26:36 26 4
gpt4 key购买 nike

我是 scikit-learn 的新手,但它满足了我的期望。现在,令人抓狂的是,唯一剩下的问题是我找不到如何打印(或者更好的是,写入一个小文本文件)它估计的所有系数,它选择的所有特征。有什么方法可以做到这一点?

与 SGDClassifier 相同,但我认为它对于所有可以适合的基础对象都是相同的,无论是否有交叉验证。完整脚本如下。

import scipy as sp
import numpy as np
import pandas as pd
import multiprocessing as mp
from sklearn import grid_search
from sklearn import cross_validation
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import SGDClassifier


def main():
print("Started.")
# n = 10**6
# notreatadapter = iopro.text_adapter('S:/data/controls/notreat.csv', parser='csv')
# X = notreatadapter[1:][0:n]
# y = notreatadapter[0][0:n]
notreatdata = pd.read_stata('S:/data/controls/notreat.dta')
notreatdata = notreatdata.iloc[:10000,:]
X = notreatdata.iloc[:,1:]
y = notreatdata.iloc[:,0]
n = y.shape[0]

print("Data lodaded.")
X_train, X_test, y_train, y_test = cross_validation.train_test_split(X, y, test_size=0.4, random_state=0)

print("Data split.")
scaler = StandardScaler()
scaler.fit(X_train) # Don't cheat - fit only on training data
X_train = scaler.transform(X_train)
X_test = scaler.transform(X_test) # apply same transformation to test data

print("Data scaled.")
# build a model
model = SGDClassifier(penalty='elasticnet',n_iter = np.ceil(10**6 / n),shuffle=True)
#model.fit(X,y)

print("CV starts.")
# run grid search
param_grid = [{'alpha' : 10.0**-np.arange(1,7),'l1_ratio':[.05, .15, .5, .7, .9, .95, .99, 1]}]
gs = grid_search.GridSearchCV(model,param_grid,n_jobs=8,verbose=1)
gs.fit(X_train, y_train)

print("Scores for alphas:")
print(gs.grid_scores_)
print("Best estimator:")
print(gs.best_estimator_)
print("Best score:")
print(gs.best_score_)
print("Best parameters:")
print(gs.best_params_)


if __name__=='__main__':
mp.freeze_support()
main()

最佳答案

装有最佳超参数的 SGDClassifier 实例存储在 gs.best_estimator_ 中。 coef_intercept_ 是该最佳模型的拟合参数。

关于python - 如何在(GridSearchCV)拟合模型后打印估计系数? (SGDRegressor),我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/24375911/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com