gpt4 book ai didi

python - pytorch 中的 if-else 语句和 torch.where 之间有什么区别?

转载 作者:行者123 更新时间:2023-12-04 12:24:31 38 4
gpt4 key购买 nike

查看代码片段:

import torch
x = torch.tensor([-1.], requires_grad=True)
y = torch.where(x > 0., x, torch.tensor([2.], requires_grad=True))
y.backward()
print(x.grad)

输出为 tensor([0.]) , 但

import torch
x = torch.tensor([-1.], requires_grad=True)
if x > 0.:
y = x
else:
y = torch.tensor([2.], requires_grad=True)
y.backward()
print(x.grad)

输出为 None .

我很困惑为什么 torch.where 的输出是 tensor([0.]) ?

更新

import torch
a = torch.tensor([[1,2.], [3., 4]])
b = torch.tensor([-1., -1], requires_grad=True)
a[:,0] = b

(a[0, 0] * a[0, 1]).backward()
print(b.grad)

输出为 tensor([2., 0.]) . (a[0, 0] * a[0, 1])b[1] 没有任何关系,但梯度为 b[1]0不是 None .

最佳答案

基于跟踪的 AD,如 pytorch,通过跟踪工作。您无法跟踪不是库拦截的函数调用的内容。通过使用 if像这样的声明,x之间没有联系和 y , 而与 where , xy在表达式树中链接。

现在,对于差异:

  • 在第一个片段中,0是函数 x ↦ x > 0 ? x : 2 的正确导数在点 -1 (因为消极的一面是恒定的)。
  • 在第二个片段中,正如我所说,xy 没有任何关系(在 else 分支中)。因此,y 的导数给定 x未定义,表示为 None .

  • (你甚至可以在 Python 中做这样的事情,但这需要更复杂的技术,比如源代码转换。我不认为 pytorch 是可能的。)

    关于python - pytorch 中的 if-else 语句和 torch.where 之间有什么区别?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/61184437/

    38 4 0
    Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
    广告合作:1813099741@qq.com 6ren.com