gpt4 book ai didi

python - Sklearn 和 StatsModels 给出了截然不同的逻辑回归答案

转载 作者:行者123 更新时间:2023-12-03 20:00:54 37 4
gpt4 key购买 nike

我正在对 bool 值 0/1 数据集进行逻辑回归(预测某个年龄给你一定数量的薪水的概率),我得到了 sklearn 和 StatsModels 非常不同的结果,其中 sklearn 是非常错误的。
我已将 sklearn 惩罚设置为 None 并将拦截项设置为 false 以使该函数更类似于 StatsModels,但我不知道如何让 sklearn 给出合理的答案。
灰线是 0 或 1 处的原始数据点,我只是将绘图上的 1 缩小到 0.1 以使其可见。
变量:

# X and Y
X = df.age.values.reshape(-1,1)
X_poly = PolynomialFeatures(degree=4).fit_transform(X)
y_bool = np.array(df.wage.values > 250, dtype = "int")

# Generate a sequence of ages
age_grid = np.arange(X.min(), X.max()).reshape(-1,1)
age_grid_poly = PolynomialFeatures(degree=4).fit_transform(age_grid)
代码如下:
# sklearn Model
clf = LogisticRegression(penalty = None, fit_intercept = False,max_iter = 300).fit(X=X_poly, y=y_bool)
preds = clf.predict_proba(age_grid_poly)

# Plot
fig, ax = plt.subplots(figsize=(8,6))
ax.scatter(X ,y_bool/10, s=30, c='grey', marker='|', alpha=0.7)
plt.plot(age_grid, preds[:,1], color = 'r', alpha = 1)
plt.xlabel('Age')
plt.ylabel('Wage')
plt.show()
sklearn result
# StatsModels
log_reg = sm.Logit(y_bool, X_poly).fit()
preds = log_reg.predict(age_grid_poly)
# Plot
fig, ax = plt.subplots(figsize=(8,6))
ax.scatter(X ,y_bool/10, s=30, c='grey', marker='|', alpha=0.7)
plt.plot(age_grid, preds, color = 'r', alpha = 1)
plt.xlabel('Age')
plt.ylabel('Wage')
plt.show()
StatsModels result

最佳答案

一旦我没有数据集或 scikit-learn 和 statsmodels 的特定版本,我就无法准确地重现结果。但是,我认为您无法成功删除代码中的正则化参数。文档指出您应该传递字符串 'none' ,不是常数 None .
请引用sklearn.linear_model.LogisticRegression文档:

penalty{‘l1’, ‘l2’, ‘elasticnet’, ‘none’}, default=’l2’ Used tospecify the norm used in the penalization. The ‘newton-cg’, ‘sag’ and‘lbfgs’ solvers support only l2 penalties. ‘elasticnet’ is onlysupported by the ‘saga’ solver. If ‘none’ (not supported by theliblinear solver), no regularization is applied.


我认为通过调查系数而不是使用绘图更容易理解差异。
您可以直接使用属性 coef_ 进行调查。对于 scikit-learn 模型和 params对于 statsmodels 模型。
从逻辑上讲,如果未正确禁用正则化参数,您应该期望 scikit-learn 模型中的系数较低。

关于python - Sklearn 和 StatsModels 给出了截然不同的逻辑回归答案,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/66830858/

37 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com