gpt4 book ai didi

algorithm - 直观理解 Adam 优化器

转载 作者:行者123 更新时间:2023-12-02 22:44:28 25 4
gpt4 key购买 nike

根据Adam的伪代码:

enter image description here

我写了一些代码:

from matplotlib import pyplot as plt
import numpy as np
# np.random.seed(42)
num = 100

x = np.arange(num).tolist()
# The following 3 sets of g_list stand for 3 types of gradient changes:
# g_list = np.random.normal(0,1,num) # gradient direction changes frequently in positive and negtive
# g_list = x # gradient direction always positive and gradient value becomes larger gradually
g_list = [10 for _ in range(num)] # gradient direction always positive and gradient value always the same

m = 0
v = 0
beta_m = 0.9
beta_v = 0.999
m_list = []
v_list = []

for i in range(1,num+1):
g = g_list[i-1]
m = beta_m*m + (1 - beta_m)*g
m = m/(1-beta_m**i)
v = beta_v*v + (1 - beta_v)*(g**2)
v = v/(1-beta_v**i)
m_list.append(m)
v_list.append(np.sqrt(v))

mv = np.array(m_list)/(np.array(v_list) +0.001)
print("==>> mv: ", mv)
plt.plot(x, g_list, x, mv)

运行代码,得到如下图:

enter image description here

对我来说,我认为这是违反直觉的,因为我认为当梯度方向始终为正且梯度值恒定时,学习率的系数(即mv)应该接近1,但是我得到的第 100 个 mv 是 3.40488818e-70,几乎接近于零。

如果我更改一些代码:

    # m = m/(1-beta_m**i)
if i == 1:
m = m/(1-beta_m**i)
# v = v/(1-beta_v**i)
if i == 1:
v = v/(1-beta_v**i)

我得到的结果是这样的:

enter image description here

这比较符合我的直觉。

有人能告诉我我上面的代码是否正确吗?如果正确,它是否符合您的直觉来获得上面的代码?

最佳答案

您的代码实现差不多了,但是您应该注意您的实现和算法之间的一个区别是您错误地累积了偏差校正项 m/(1-beta_m**i)变量 m。您应该为偏差校正分配一个单独的变量 m_hat

这同样适用于v:将偏差校正值分配给另一个变量,如v_hat

这样做将避免在 mv 的累加中包含偏差校正。

您的代码可以保持不变,但更改偏差校正值的计算以及列表附加值。如果这样做,您将获得想要的结果。

for i in range(1,num+1):
g = g_list[i-1]

# calculate m and v
m = beta_m*m + (1 - beta_m)*g
v = beta_v*v + (1 - beta_v)*(g**2)

# assign bias corrected values to m_hat and v_hat respectively
m_hat = m/(1-beta_m**i)
v_hat = v/(1-beta_v**i)

# append to lists
m_list.append(m_hat)
v_list.append(np.sqrt(v_hat))

关于algorithm - 直观理解 Adam 优化器,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/71185171/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com