gpt4 book ai didi

python-3.x - 线性回归中与一个特征的每次交互都会增加梯度下降成本

转载 作者:行者123 更新时间:2023-11-30 09:34:30 25 4
gpt4 key购买 nike

您好,我正在学习一些机器学习算法,为了理解,我试图实现一种线性回归算法,其一个特征使用梯度下降法的残差平方和作为成本函数,如下所示:

我的伪代码:

 while not converge
w <- w - step*gradient

Python代码线性.py

import math
import numpy as num

def get_regression_predictions(input_feature, intercept, slope):
predicted_output = [intercept + xi*slope for xi in input_feature]
return(predicted_output)

def rss(input_feature, output, intercept,slope):
return sum( [ ( output.iloc[i] - (intercept + slope*input_feature.iloc[i]) )**2 for i in range(len(output))])

def train(input_feature,output,intercept,slope):


file = open("train.csv","w")
file.write("ID,intercept,slope,RSS\n")
i =0

while True:

print("RSS:",rss(input_feature, output, intercept,slope))
file.write(str(i)+","+str(intercept)+","+str(slope)+","+str(rss(input_feature, output, intercept,slope))+"\n")
i+=1

gradient = [derivative(input_feature, output, intercept,slope,n) for n in range(0,2) ]

step = 0.05
intercept -= step*gradient[0]
slope-= step*gradient[1]
return intercept,slope


def derivative(input_feature, output, intercept,slope,n):
if n==0:
return sum( [ -2*(output.iloc[i] - (intercept + slope*input_feature.iloc[i])) for i in range(0,len(output))] )
return sum( [ -2*(output.iloc[i] - (intercept + slope*input_feature.iloc[i]))*input_feature.iloc[i] for i in range(0,len(output))] )

与主程序:

import Linear as lin
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split




df = pd.read_csv("test2.csv")


train = df

lin.train(train["X"],train["Y"], 0, 0)

test2.csv:

X,Y
0,1
1,3
2,7
3,13
4,21

我抵制了文件中 rss 的值,并注意到 rss 的值在每次迭代中变得最差,如下所示:

ID,intercept,slope,RSS
0,0,0,669
1,4.5,14.0,3585.25
2,-7.25,-18.5,19714.3125
3,19.375,58.25,108855.953125

从数学上讲,我认为这没有任何意义,我多次检查自己的代码,我认为它是正确的,我做错了什么?

最佳答案

如果您的成本没有下降,这通常表明您的梯度下降方法超出了范围,这意味着步长太大。

Large step size

较小的步长会有所帮助。您还可以研究可变步长的方法,它可以改变每次迭代以获得良好的收敛特性和速度;通常,这些方法改变步长与梯度成一定比例。当然,具体情况取决于每个问题。

关于python-3.x - 线性回归中与一个特征的每次交互都会增加梯度下降成本,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/46569899/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com