gpt4 book ai didi

python - 在 Python 中实现线性回归

转载 作者:行者123 更新时间:2023-11-30 09:09:01 25 4
gpt4 key购买 nike

我刚刚开始使用 Siraj Raval 在 youtube 上的视频进行机器学习,并尝试了视频“Intro - The Math of Intelligence”的挑战,即使用 kaggle.com 的数据集使用梯度下降执行线性回归。这是我的代码:

"""
An Example of a Linear Regression model.

Here i am taking an example from https://www.kaggle.com/alopez247/pokemon
to find a relation between variable "Total" and "HP".

"""
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import sys
import os

data = pd.read_csv("./pokemon_alopez247.csv")
d = {"Total": data['Total'],
"HP": data['HP']}
smallData = pd.DataFrame(d)
test = smallData.values
epsilon = 0.001


def compute_error_for_line(b, m, points):
"""Return the Error for Line given the points."""
totalError = 0
for i in range(0, len(points)):
x = test[i, 0]
y = test[i, 1]
totalError += (y - (m * x + b)) ** 2
return totalError / float(len(points))


def step_gradient(b_current, m_current, points, learningRate):
"""Return the new b and m points."""
b_gradient = 0
m_gradient = 0
N = float(len(points))
for i in range(0, len(points)):
x = points[i, 0]
y = points[i, 1]
error = y - ((m_current * x) + b_current)
b_gradient += -(2 / N) * error
m_gradient += -(2 / N) * x * error
new_b = b_current - (learningRate * b_gradient)
new_m = m_current - (learningRate * m_gradient)
return [new_b, new_m]


def main():
"""Return and plot function here."""
plt.figure(num=None, figsize=(20, 10), dpi=80,
facecolor='w', edgecolor='k')
plt.axis([0, 780, 0, 260])
plt.ylabel("Total")
plt.xlabel("HP")
plt.scatter(test[:, [1]], test[:, [0]], c='r', s=1)

m = 0.3
b = -30
x = np.arange(800)
y = m * x + b
for i in range(30):
error = compute_error_for_line(b, m, test)
print("error :", error)
if(error > epsilon):
y = m * x + b
plt.plot(x, y)
b, m = step_gradient(b, m, test, 0.0001)
print("b , m :", b, ",", m)
plt.pause(0.01)

plt.show()

plt.pause(0.001)

if __name__ == '__main__':
try:
main()
except KeyboardInterrupt:
print('Interrupted')
try:
sys.exit(0)
except SystemExit:
os._exit(0)

输出是:

error : 193676.072288
b , m : -29.91451362 , 6.46934413315
/usr/local/lib/python3.5/dist-packages/matplotlib/backend_bases.py:2445: MatplotlibDeprecationWarning: Using default event loop until function specific to this GUI is implemented
warnings.warn(str, mplDeprecation)
error : 16427.2683093
b , m : -29.9134163218 , 6.04491523016
error : 15588.2873385
b , m : -29.9065147511 , 6.07401898958
error : 15583.8939554
b , m : -29.9000125838 , 6.07192788394
error : 15583.4489928
b , m : -29.8934831191 , 6.07198242461
error : 15583.0227312
b , m : -29.8869557061 , 6.07188938575
error : 15582.5965792
b , m : -29.8804283262 , 6.07180649992
error : 15582.1704489
b , m : -29.8739011182 , 6.07172291798
error : 15581.74434
b , m : -29.8673740726 , 6.07163938615
error : 15581.3182523
b , m : -29.86084719 , 6.0715558531
error : 15580.8921858
b , m : -29.8543204704 , 6.07147232236
error : 15580.4661407
b , m : -29.8477939138 , 6.0713887937
error : 15580.0401168
b , m : -29.8412675201 , 6.07130526712
error : 15579.6141143
b , m : -29.8347412894 , 6.07122174263
error : 15579.1881329
b , m : -29.8282152217 , 6.07113822022
error : 15578.7621729
b , m : -29.821689317 , 6.0710546999
error : 15578.3362341
b , m : -29.8151635752 , 6.07097118166
error : 15577.9103166
b , m : -29.8086379963 , 6.07088766551
error : 15577.4844204
b , m : -29.8021125804 , 6.07080415145
error : 15577.0585455
b , m : -29.7955873275 , 6.07072063947
error : 15576.6326918
b , m : -29.7890622375 , 6.07063712957
error : 15576.2068594
b , m : -29.7825373104 , 6.07055362176
error : 15575.7810482
b , m : -29.7760125462 , 6.07047011604
error : 15575.3552583
b , m : -29.769487945 , 6.0703866124
error : 15574.9294897
b , m : -29.7629635067 , 6.07030311084
error : 15574.5037423
b , m : -29.7564392314 , 6.07021961138
error : 15574.0780162
b , m : -29.7499151189 , 6.07013611399
error : 15573.6523114
b , m : -29.7433911694 , 6.07005261869
error : 15573.2266278
b , m : -29.7368673827 , 6.06996912548
error : 15572.8009655
b , m : -29.730343759 , 6.06988563435
[Finished in 73.209s]

因此输出表明一切都按计划进行。但看看this 。第一个蓝色是原始值,并且线越来越远!我尝试重写compute_error_for_line和step_gradient函数,但仍然没有任何结果。感谢您阅读到最后。

那么如何获得最适合我的样本空间的线参数呢?

链接到我的 csv 文件 here (此文件将在 22 小时后过期)。

最佳答案

    plt.scatter(test[:, [1]], test[:, [0]], c='r', s=1)

看起来你已经交换了 x 和 y 值。如果将 [1] 更改为 [0],反之亦然,则情节看起来相当不错

关于python - 在 Python 中实现线性回归,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/45045355/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com