gpt4 book ai didi

numpy - 正向与反向模式区分 - Pytorch

转载 作者:行者123 更新时间:2023-12-04 10:59:52 24 4
gpt4 key购买 nike

Learning PyTorch with Examples 的第一个示例中,作者演示了如何使用 numpy 创建神经网络。为方便起见,他们的代码粘贴在下面:

# from: https://pytorch.org/tutorials/beginner/pytorch_with_examples.html
# -*- coding: utf-8 -*-
import numpy as np

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random input and output data
x = np.random.randn(N, D_in)
y = np.random.randn(N, D_out)

# Randomly initialize weights
w1 = np.random.randn(D_in, H)
w2 = np.random.randn(H, D_out)

learning_rate = 1e-6
for t in range(500):
# Forward pass: compute predicted y
h = x.dot(w1)
h_relu = np.maximum(h, 0)
y_pred = h_relu.dot(w2)

# Compute and print loss
loss = np.square(y_pred - y).sum()
print(t, loss)

# Backprop to compute gradients of w1 and w2 with respect to loss
grad_y_pred = 2.0 * (y_pred - y)
grad_w2 = h_relu.T.dot(grad_y_pred)
grad_h_relu = grad_y_pred.dot(w2.T)
grad_h = grad_h_relu.copy()
grad_h[h < 0] = 0
grad_w1 = x.T.dot(grad_h)

# Update weights
w1 -= learning_rate * grad_w1
w2 -= learning_rate * grad_w2

令我困惑的是为什么要计算 w1 和 w2 的梯度 关于损失 (倒数第二个代码块)。

通常会发生相反的计算:损失的梯度是根据权重计算的,如下所示:
  • “在训练神经网络时,我们将成本(描述神经网络执行情况的值)视为参数(描述网络行为方式的数字)的函数。 我们想计算成本的导数关于所有参数,用于梯度下降。现在,神经网络中通常有数百万甚至数千万个参数。因此,反向模式微分,在神经网络的上下文中称为反向传播,给出我们大幅加速!” (Colah's blog)。

  • 所以我的问题是:与正常的反向传播计算相比,为什么上面示例中的推导计算顺序相反?

    最佳答案

    似乎是评论中的错字。他们实际上是在计算 loss 的梯度。 w.r.t. w2w1 .

    让我们快速推导出 loss 的梯度w.r.t. w2只是要确定。通过检查您的代码,我们有

    enter image description here

    使用微积分中的链式法则

    enter image description here .

    每一项都可以使用矩阵演算的基本规则来表示。这些原来是

    enter image description here



    enter image description here .

    将这些项代入我们得到的初始方程

    enter image description here .

    完美匹配所描述的表达式

    grad_y_pred = 2.0 * (y_pred - y)       # gradient of loss w.r.t. y_pred
    grad_w2 = h_relu.T.dot(grad_y_pred) # gradient of loss w.r.t. w2

    在您提供的反向传播代码中。

    关于numpy - 正向与反向模式区分 - Pytorch,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/58886606/

    24 4 0
    Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
    广告合作:1813099741@qq.com 6ren.com