gpt4 book ai didi

pytorch - 使用 autograd 计算相对于输入的输出雅可比矩阵

转载 作者:行者123 更新时间:2023-12-02 18:54:54 24 4
gpt4 key购买 nike

如果这个问题很明显或微不足道,我深表歉意。我对 pytorch 很陌生,我想了解 pytorch 中的 autograd.grad 函数。我有一个神经网络 G,它接受输入 (x,t) 和输出 (u,v)。这是 G 的代码:

class GeneratorNet(torch.nn.Module):
"""
A three hidden-layer generative neural network
"""

def __init__(self):
super(GeneratorNet, self).__init__()
self.hidden0 = nn.Sequential(
nn.Linear(2, 100),
nn.LeakyReLU(0.2)
)

self.hidden1 = nn.Sequential(
nn.Linear(100, 100),
nn.LeakyReLU(0.2)
)

self.hidden2 = nn.Sequential(
nn.Linear(100, 100),
nn.LeakyReLU(0.2)
)

self.out = nn.Sequential(
nn.Linear(100, 2),
nn.Tanh()
)

def forward(self, x):
x = self.hidden0(x)
x = self.hidden1(x)
x = self.hidden2(x)
x = self.out(x)
return x

或者简单地 G(x,t) = (u(x,t), v(x,t)) 其中 u(x,t) 和 v(x,t) 是标量值。目标:计算$\frac{\partial u(x,t)}{\partial x}$和$\frac{\partial u(x,t)}{\partial t}$。在每个训练步骤中,我都有一个大小为 100 美元的小批量,因此 u(x,t) 是一个 [100,1] 张量。这是我计算偏导数的尝试,其中坐标是输入 (x,t),就像下面一样,我也在坐标中添加了 requires_grad_(True) 标志:

tensor = GeneratorNet(coords)
tensor.requires_grad_(True)
u, v = torch.split(tensor, 1, dim=1)
du = autograd.grad(u, coords, grad_outputs=torch.ones_like(u), create_graph=True,
retain_graph=True, only_inputs=True, allow_unused=True)[0]

du 现在是一个 [100,2] 张量。问题:这是小批量的 100 个输入点的部分张量吗?

还有类似的问题,例如计算输出相对于输入的导数,但我无法真正弄清楚发生了什么。如果这个问题已经得到解答或微不足道,我再次表示歉意。非常感谢。

最佳答案

您发布的代码应该为您提供第一个输出的偏导数。输入。但是,您还必须在输入上设置 requires_grad_(True) ,否则 PyTorch 不会从输入开始构建计算图,因此无法计算它们的梯度。

此版本的代码示例计算 dudv:

net = GeneratorNet()
coords = torch.randn(10, 2)
coords.requires_grad = True
tensor = net(coords)
u, v = torch.split(tensor, 1, dim=1)
du = torch.autograd.grad(u, coords, grad_outputs=torch.ones_like(u))[0]
dv = torch.autograd.grad(v, coords, grad_outputs=torch.ones_like(v))[0]

您还可以计算单个输出的偏导数:

net = GeneratorNet()
coords = torch.randn(10, 2)
coords.requires_grad = True
tensor = net(coords)
u, v = torch.split(tensor, 1, dim=1)
du_0 = torch.autograd.grad(u[0], coords)[0]

其中du_0 == du[0]

关于pytorch - 使用 autograd 计算相对于输入的输出雅可比矩阵,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/59161001/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com