gpt4 book ai didi

python - pytorch 是否对其计算图进行急切修剪?

转载 作者:太空宇宙 更新时间:2023-11-04 04:19:26 25 4
gpt4 key购买 nike

这是一个非常简单的例子:

import torch

x = torch.tensor([1., 2., 3., 4., 5.], requires_grad=True)
y = torch.tensor([2., 2., 2., 2., 2.], requires_grad=True)
z = torch.tensor([1., 1., 0., 0., 0.], requires_grad=True)

s = torch.sum(x * y * z)
s.backward()

print(x.grad)

这将打印,

tensor([2., 2., 0., 0., 0.]),

当然,对于 z 为零的条目,ds/dx 为零。

我的问题是:pytorch 是否智能并在达到零时停止计算?或者实际上是在计算“2*5”,只是为了稍后做“10 * 0 = 0”?

在这个简单的例子中,它并没有太大的不同,但在我正在研究的(更大的)问题中,这会产生很大的不同。

感谢您的任何输入。

最佳答案

不,pytorch 不会在达到零时修剪任何后续计算。更糟糕的是,由于浮点运算的工作原理,所有后续的零乘法将花费与任何常规乘法大致相同的时间。

但对于某些情况,有一些解决方法,例如,如果您想使用掩蔽损失,您可以将掩蔽输出设置为零,或者将它们从梯度中分离出来。

这个例子清楚地表明了区别:

def time_backward(do_detach):
x = torch.tensor(torch.rand(100000000), requires_grad=True)
y = torch.tensor(torch.rand(100000000), requires_grad=True)
s2 = torch.sum(x * y)
s1 = torch.sum(x * y)
if do_detach:
s2 = s2.detach()
s = s1 + 0 * s2
t = time.time()
s.backward()
print(time.time() - t)

time_backward(do_detach= False)
time_backward(do_detach= True)

输出:

0.502875089645
0.198422908783

关于python - pytorch 是否对其计算图进行急切修剪?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/54781966/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com