gpt4 book ai didi

machine-learning - 小批量 SGD 梯度计算 - 平均值或总和

转载 作者:行者123 更新时间:2023-11-30 08:51:12 25 4
gpt4 key购买 nike

我试图了解使用 miinibatch SGD 时如何计算梯度。我已经在 CS231 在线类(class)中实现了它,但后来才意识到,在中间层中,梯度基本上是为每个样本计算的所有梯度的总和(与 Ca​​ffe 或 Tensorflow 中的实现相同)。仅在最后一层(损失)中,它们才按样本数量进行平均。它是否正确?如果是这样,是否意味着由于在最后一层中它们被平均,所以在进行反向传播时,所有梯度也会自动平均?谢谢!

最佳答案

最好先了解 SGD 为何有效。

通常,神经网络实际上是一个非常复杂的复合函数,由输入向量 x、标签 y(或目标变量,根据问题是分类还是回归而变化)和一些参数向量 w 组成。假设我们正在进行分类。我们实际上正在尝试对变量向量 w 进行最大似然估计(实际上是 MAP 估计,因为我们肯定会使用 L2 或 L1 正则化,但目前技术性太强)。假设样本是独立的;那么我们有以下成本函数:

p(y1|w,x1)p(y2|w,x2) ... p(yN|w,xN)

将 wrt 优化为 w 是一团糟,因为所有这些概率都相乘(这将产生 wrt w 的极其复杂的导数)。我们使用对数概率代替(取对数不会改变极值点,并且我们除以 N,因此我们可以将训练集视为经验概率分布 p(x))

J(X,Y,w)=-(1/N)(log p(y1|w,x1) + log p(y2|w,x2) + ... + log p(yN|w,xN))

这是我们实际的成本函数。神经网络实际上所做的是对概率函数 p(yi|w,xi) 进行建模。这可以是一个非常复杂的 1000 多个分层 ResNet,也可以只是一个简单的感知器。

现在 w 的导数很容易表述,因为我们现在有了一个加法:

dJ(X,Y,w)/dw = -(1/N)(dlog p(y1|w,x1)/dw + dlog p(y2|w,x2)/dw + ... + dlog p(yN|w,xN)/dw)

理想情况下,上面是实际的渐变。但这种批量计算并不容易计算。如果我们正在处理包含 100 万个训练样本的数据集怎么办?更糟糕的是,训练集可能是样本 x 的流,其大小无限。

SGD 的随机部分在这里发挥作用。从训练集中随机均匀地选取 m << N 的 m 个样本,并使用它们计算导数:

 dJ(X,Y,w)/dw =(approx) dJ'/dw = -(1/m)(dlog p(y1|w,x1)/dw + dlog p(y2|w,x2)/dw + ... + dlog p(ym|w,xm)/dw)

请记住,我们有一个经验(或无限训练集情况下的实际)数据分布 p(x)。上述从 p(x) 中抽取 m 个样本并对它们进行平均的操作实际上产生了实际导数 dJ(X,Y,w)/dw 的无偏估计量 dJ'/dw。这意味着什么?采取许多这样的 m 个样本并计算不同的 dJ'/dw 估计值,对它们进行平均,然后在无限采样的限制下,您会得到非常接近甚至精确的 dJ(X,Y,w)/dw 。可以证明,从长远来看,这些有噪声但无偏的梯度估计将表现得像原始梯度一样。平均而言,SGD 将遵循实际梯度的路径(但它可能会陷入不同的局部最小值,这一切都取决于学习率的选择)。小批量大小 m 与噪声估计 dJ'/dw 中的固有误差直接相关。如果 m 很大,你会得到低方差的梯度估计,你可以使用更大的学习率。如果m很小或者m=1(在线学习),估计器dJ'/dw的方差非常高,你应该使用较小的学习率,否则算法很容易发散失控。

现在理论已经足够了,你的实际问题是

It is only in the last layer (the loss) that they are averaged by the number of samples. Is this correct? if so, does it mean that since in the last layer they are averaged, when doing backprop, all the gradients are also averaged automatically? Thanks!

是的,在最后一层除以 m 就足够了,因为一旦最下层乘以因子 (1/m),链式法则就会将因子 (1/m) 传播到所有参数。您不需要为每个参数单独执行,这将是无效的。

关于machine-learning - 小批量 SGD 梯度计算 - 平均值或总和,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/41145831/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com