gpt4 book ai didi

python - Python 中的二维梯度下降

转载 作者:行者123 更新时间:2023-11-30 09:10:40 27 4
gpt4 key购买 nike

我无法理解二维梯度下降。假设我有功能 f(x,y)=x**2-xy哪里df/dx = 2x-ydf/dy = -x .

因此对于点 df(2,3),输出向量为 [1, -2].T。向量 [1,-2] 指向的位置都是最陡上升的方向(也称为 f(x,y) 的输出)。 我应该选择一个固定的步长,并找到该大小的步长使 f(x,y) 增加最多的方向。如果我想下降,我想找到-f(x,y)增加最快的方向?

如果我的直觉是正确的,你会如何编写这个代码?假设我从点 (x=0, y=5) 开始,我想执行梯度下降来找到最小值。

step_size = 0.01
precision = 0.00001 #stopping point
enter code here??

最佳答案

这是使用 matplotlib 可视化实现梯度下降:

import csv
import math
def loadCsv(filename):
lines = csv.reader(open(filename, "r"))
dataset = list(lines)
for i in range(len(dataset)):
dataset[i] = [float(x) for x in dataset[i]]
return dataset

def h(o1,o2,x):
ans=o1+o2*x
return ans

def costf(massiv,p1,p2):
sum1=0.0
sum2=0.0
for x,y in massiv:
sum1+=(math.pow(h(o1,o2,x)-y,2))
sum2=(1.0/(2*len(massiv)))*sum1
return sum1,sum2

def gradient(massiv,er,alpha,o1,o2,max_loop=1000):
i=0
J,e=costf(massiv,o1,o2)
conv=False
m=len(massiv)
while conv!=True:
sum1=0.0
sum2=0.0
for x,y in massiv:
sum1+=(o1+o2*x-y)
sum2+=(o1+o2*x-y)*x
grad0=1.0/m*sum1
grad1=1.0/m*sum2

temp0=o1-alpha*grad0
temp1=o2-alpha*grad1
print(temp0,temp1)
o1=temp0
o2=temp1
e=0.0
for x,y in massiv:
e+=(math.pow(h(o1,o2,x)-y,2))
if abs(J-e)<=ep:
print('Successful\n')
conv=True

J=e

i+=1
if i>=max_loop:
print('Too much\n')
break
return o1,o2


#data = massiv
data=loadCsv('ex1data1.txt')
o1=0.0 #temp0=0
o2=1.0 #temp1=1
alpha=0.01
ep=0.01
t0,t1=gradient(data,ep,alpha,o1,o2)
print('temp0='+str(t0)+' \ntemp1='+str(t1))

x=35000
while x<=70000:
y=h(t0,t1,x)
print('x='+str(x)+'\ny='+str(y)+'\n')
x+=5000

maxx=data[0][0]
for q,w in data:
maxx=max(maxx,q)
maxx=round(maxx)+1
line=[]
ll=0
while ll<maxx:
line.append(h(t0,t1,ll))
ll+=1
x=[]
y=[]
for q,w in data:
x.append(q)
y.append(w)

import matplotlib.pyplot as plt
plt.plot(x,y,'ro',line)
plt.ylabel('some numbers')
plt.show()
<小时/>

Matplotlib 输出:

enter image description here

ex1data1.txt可以从这里下载: ex1data1.txt

可以使用 Python 3.5 在 Anaconda 发行版中按原样执行代码。

关于python - Python 中的二维梯度下降,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/39681883/

27 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com