python-3.x - How to wrap a PyTorch function and implement autograd?


I am working through the PyTorch tutorial on Defining new autograd functions. The autograd function I want to implement is a wrapper around torch.nn.functional.max_pool1d. Here is what I have so far:

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as tag

class SquareAndMaxPool1d(tag.Function):

    @staticmethod
    def forward(ctx, input, kernel_size, stride=None, padding=0, dilation=1, \
                return_indices=False, ceil_mode=False):
        ctx.save_for_backward( input )

        inputC = input.clone() #copy input
        inputC *= inputC

        output = F.max_pool1d(inputC, kernel_size, stride=stride, \
                              padding=padding, dilation=dilation, \
                              return_indices=return_indices, \
                              ceil_mode=ceil_mode)

        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = get_max_pool1d_grad_somehow(grad_output)
        return 2.0*input*grad_input

My question is: how do I get the gradient of the wrapped function? I know there are probably other ways to do this given how simple my example is, but what I want to do fits this framework and requires that I implement an autograd function.

Edit: after looking at this blog post I decided to try the following backward:
def backward(ctx, grad_output):
    input, output = ctx.saved_tensors
    grad_input = output.backward(grad_output)
    return 2.0*input*grad_input

with output added to the saved variables. I then run the following code:
x = np.random.randn(1,1,5)
xT = torch.from_numpy(x)
xT.requires_grad=True
f = SquareAndMaxPool1d.apply
s = torch.sum(f(xT,2))
s.backward()

and I get Bus error: 10.

Say xT is tensor([[[ 1.69533562, -0.21779421, 2.28693953, -0.86688095, -1.01033497]]], dtype=torch.float64); then after calling s.backward() I would expect xT.grad to be tensor([[[ 3.39067124, -0., 9.14775812, -0., -2.02066994]]], dtype=torch.float64), i.e. 2*x*grad_of_max_pool, where grad_of_max_pool contains tensor([[[1., 0., 2., 0., 1.]]], dtype=torch.float64).

I have figured out why I get Bus error: 10: the code above apparently leads to a recursive call of my backward at grad_input = output.backward(grad_output). So I need to find some other way of getting the gradient of max_pool1d. I know how to implement this in pure Python, but the result would be much slower than if I could wrap the library code.
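For reference, such a pure-Python gradient might look like the following minimal sketch (the function name and the restriction to the unpadded, undilated case are my own simplifications); the nested Python loops over batches, channels and windows are what would make it slow compared to wrapping the library code.

import torch

def max_pool1d_grad_py(inp, grad_output, kernel_size, stride=None):
    # hypothetical pure-Python gradient of max_pool1d, ignoring padding,
    # dilation and ceil_mode; loops over every pooling window in Python
    stride = stride if stride is not None else kernel_size
    grad_input = torch.zeros_like(inp)
    batch, channels, out_len = grad_output.shape
    for b in range(batch):
        for c in range(channels):
            for j in range(out_len):
                start = j * stride
                window = inp[b, c, start:start + kernel_size]
                k = int(torch.argmax(window))
                # route the output gradient back to the argmax position
                grad_input[b, c, start + k] += grad_output[b, c, j]
    return grad_input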

Best Answer

You have picked a rather unfortunate example. torch.nn.functional.max_pool1d is not an instance of torch.autograd.Function, because it is a PyTorch built-in, defined in C++ code with autogenerated Python bindings. I am not sure whether it is possible to get at its backward through its interface.

First of all, in case you haven't noticed, you do not need to write any custom code for backpropagating this formula, because both the power operation and max_pool1d already have it defined, so their composition is also covered by autograd. Assuming your goal is an exercise, I suggest you do it more manually (without falling back to the backward of max_pool1d). An example is below.

import torch
import torch.nn.functional as F
import torch.autograd as tag

class SquareAndMaxPool1d(tag.Function):
    @staticmethod
    def forward(ctx, input, kernel_size, **kwargs):
        # we're gonna need indices for backward. Currently SquareAnd...
        # never actually returns indices, I left it out for simplicity
        kwargs['return_indices'] = True

        input_sqr = input ** 2
        output, indices = F.max_pool1d(input_sqr, kernel_size, **kwargs)
        ctx.save_for_backward(input, indices)

        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, indices = ctx.saved_tensors

        # first we need to reconstruct the gradient of `max_pool1d`
        # by putting all the output gradient elements (corresponding to
        # input elements which made it through the max_pool1d) in their
        # respective places, the rest has gradient of 0. We do it by
        # scattering it against a tensor of 0s
        grad_output_unpooled = torch.zeros_like(input)
        grad_output_unpooled.scatter_(2, indices, grad_output)

        # then incorporate the gradient of the "square" part of your
        # operator
        grad_input = 2. * input * grad_output_unpooled

        # the docs for backward
        # https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function.backward
        # say that "it should return as many tensors, as there were inputs
        # to forward()". It fails to mention that if an argument was not a
        # tensor, it should return None (I remember reading this somewhere,
        # but can't find it anymore). Anyway, we need to
        # return a (grad_input, None) tuple to avoid a complaint that two
        # outputs were expected
        return grad_input, None

We can then use the numerical gradient checker to verify that the operation works as expected.
f = SquareAndMaxPool1d.apply
xT = torch.randn(1, 1, 6, requires_grad=True, dtype=torch.float64)
tag.gradcheck(lambda t: f(t, 2), xT)
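To illustrate the first point above, here is a minimal sketch (shapes and values are arbitrary) showing that the composition of the built-in power operation and max_pool1d already backpropagates without any custom Function:

import torch
import torch.nn.functional as F

x = torch.randn(1, 1, 6, dtype=torch.float64, requires_grad=True)
# square then max-pool using only built-in, already-differentiable ops
s = torch.sum(F.max_pool1d(x ** 2, 2))
s.backward()
print(x.grad)  # autograd composed the gradients of ** 2 and max_pool1d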

I'm sorry if this does not answer your question of how to get the backward of max_pool1d, but hopefully you find my answer useful enough.

Regarding python-3.x - How to wrap a PyTorch function and implement autograd?, we found a similar question on Stack Overflow: https://stackoverflow.com/questions/54586938/
