gpt4 book ai didi

python-3.x - python中简单实现线性回归时出现溢出错误

转载 作者:行者123 更新时间:2023-11-30 08:59:13 25 4
gpt4 key购买 nike

我在运行代码时遇到此错误:

ErrorValue += ((m*x + b) - y)**2 RuntimeWarning: double_scalars 中遇到溢出

有人可以解释一下我的代码有什么问题吗?如果可以对我的线性reg尝试的正确性以及改进代码的方法提出有用的建议,那就太好了。

非常感谢!

import csv
import matplotlib.pyplot as plt
from numpy import *

'''
This is a simple implementation of linear regression on correlation
hours studied by student and the marks they obtained.
'''
def run():

points = genfromtxt("data.csv", delimiter=",")

# x is hours studied, y is marks obtained.

# We are applying the function: y = b + mx
for i in range(len(points)):
x = points[i][0]
y = points[i][1]

N = len(points)
b = 0
m = 0
alpha = 0.001 # alpha is the learning rate
ErrorThreshold = 0.003
NumberOfIterations = 1000 # We cancel the gradient descent after a number of iterations, if it still doesn't reach the threshold we want.

sum_m = 0
sum_b = 0

for i in range(NumberOfIterations):
while mean_squared_error(x,y,b,m,points) > ErrorThreshold:
b , m = gradient_descent(m,b,alpha,x,y,N,points)

def mean_squared_error(x,y,b,m,points):
ErrorValue = 0
for i in range(len(points)):
ErrorValue += ((m*x + b) - y)**2
return ErrorValue / len(points)


def gradient_descent(m,b,alpha,N,x,y,points):

#dealing with summation sign in gradient descent
sum_m = 0
sum_b = 0

for i in range(len(points)):
x = points[i][0]
y = points[i][1]
sum_m += m*x + b - y
sum_b += m*x + b - y
#repeating just for clarification purposes.

new_b = b - (2/N)*sum_b
new_m = m - (((2*m)/N))*sum_m

return new_b, new_m

if __name__ == '__main__':
run()

最佳答案

当您跟踪(打印)中间值时,您会看到什么?例如,在您的例程中添加几行。

def mean_squared_error(x,y,b,m,points):
print("ENTER", x, y, b, m, len(points))
ErrorValue = 0
for i in range(len(points)):
ErrorValue += ((m*x + b) - y)**2
print("TRACE", i, ErrorValue)
return ErrorValue / len(points)

另外,我不确定你的计算是否正确;我想你可能想要

        ErrorValue += ((m*x[i] + b) - y[i])**2

目前,您正在乘以整个向量,而不仅仅是标量,但您正在执行 len(points) 次。

最后,使用len(x)并且根本不传入points会更容易吗?除了长度之外,您没有使用

关于python-3.x - python中简单实现线性回归时出现溢出错误,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/47060561/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com