gpt4 book ai didi

machine-learning - Batchnorm 中的反向传播更新错误

转载 作者:行者123 更新时间:2023-11-30 09:18:01 25 4
gpt4 key购买 nike

我有这些 Backprop 更新,请让我知道 dx 部分哪里出了问题。在计算图中,我使用X、sample_mean 和sample_var。感谢您的帮助

(x, norm, sample_mean, sample_var, gamma, eps) = cache
dbeta = np.sum(dout, axis = 0)
dgamma = np.sum(dout * norm, axis = 0)
dxminus = dout * gamma / np.sqrt(sample_var + eps)
dmean = - np.sum(dxminus, axis = 0)
dxmean = np.full(x.shape, 1.0/x.shape[0]) * dmean
dvar = np.sum(dout * gamma * (x - sample_mean), axis = 0)
dxvar = dvar * (-1 / x.shape[0]) * np.power(x, -1.5) * (x - sample_mean)
dx = dxminus + dxmean + dxvar

BatchNorm Computational Graph I used for deriving

最佳答案

您的 dx 公式看起来不正确,因为 x 节点将从其他两个节点接收向后消息(一个是总和,另一个是意思是),看起来您只计算一个组件:

backprop

所以它应该看起来像这样:

dx1 = dxmu1 + dxmu2
dmu = -1 * np.sum(dxmu1+dxmu2, axis=0)
dx2 = 1. /N * np.ones((N,D)) * dmu
dx = dx1 + dx2

图片来自this wonderful post 。您也可以在那里找到完整的代码。

关于machine-learning - Batchnorm 中的反向传播更新错误,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/50136187/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com