gpt4 book ai didi

algorithm - 随机梯度下降收敛太平滑

转载 作者:塔克拉玛干 更新时间:2023-11-03 05:17:18 25 4
gpt4 key购买 nike

作为我家庭作业的一部分,我被要求实现随机梯度下降以解决线性回归问题(尽管我只有 200 个训练样本)。我的问题是随机梯度下降收敛得太平滑了,几乎和批量梯度下降完全一样,这让我想到了我的问题:考虑到通常它的噪声要大得多,为什么它看起来如此平滑。是因为我只用了 200 个例子吗?

收敛图:

Stochastic gradient descent

Gradient descent

具有随机梯度下降权重的 MSE:2.78441258841

具有梯度下降权重的 MSE:2.78412631451(与具有正规方程权重的 MSE 相同)

我的代码:

def mserror(y, y_pred):

n = y.size
diff = y - y_pred
diff_squared = diff ** 2
av_er = float(sum(diff_squared))/n

return av_er

.

def linear_prediction(X, w):
return dot(X,np.transpose(w))

.

def gradient_descent_step(X, y, w, eta):

n = X.shape[0]

grad = (2.0/n) * sum(np.transpose(X) * (linear_prediction(X,w) - y), axis = 1)

return w - eta * grad

.

def stochastic_gradient_step(X, y, w, train_ind, eta):

n = X.shape[0]

grad = (2.0/n) * np.transpose(X[train_ind]) * (linear_prediction(X[train_ind],w) - y[train_ind])

return w - eta * grad

.

def gradient_descent(X, y, w_init, eta, max_iter):

w = w_init
errors = []
errors.append(mserror(y, linear_prediction(X,w)))

for i in range(max_iter):
w = gradient_descent_step(X, y, w, eta)
errors.append(mserror(y, linear_prediction(X,w)))

return w, errors

.

def stochastic_gradient_descent(X, y, w_init, eta, max_iter):

n = X.shape[0]
w = w_init

errors = []
errors.append(mserror(y, linear_prediction(X,w)))

for i in range(max_iter):

random_ind = np.random.randint(n)

w = stochastic_gradient_step(X, y, w, random_ind, eta)
errors.append(mserror(y, linear_prediction(X,w)))

return w, errors

最佳答案

您的图表没有任何异常。您还应该注意,您的批处理方法需要更少的迭代来收敛。

您可能会让来自神经网络的 SGD 图影响您对 SGD“应该”是什么样子的看法。大多数神经网络都是更复杂的模型(难以优化),用于解决更难的问题。这导致了您可能期望的“锯齿状”。

线性回归是一个简单的问题,并且有一个凸解。这意味着任何降低错误率的步骤都保证是朝着最佳解决方案迈出的一步。这比神经网络复杂得多,也是您看到错误减少平稳的部分原因。这也是您看到几乎相同的 MSE 的原因。 SGD 和批处理收敛到完全相同的解决方案。

如果你想尝试强制一些非平滑性,你可以不断增加学习率 eta,但这是一种愚蠢的做法。最终你会到达一个你不收敛的点,因为你总是越过解决方案。

关于algorithm - 随机梯度下降收敛太平滑,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/42766970/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com