gpt4 book ai didi

python - tensorflow在梯度计算时如何处理不可微节点?

转载 作者:太空狗 更新时间:2023-10-30 01:25:02 25 4
gpt4 key购买 nike

我理解自动微分的概念,但找不到任何解释 tensorflow 如何计算不可微分函数的误差梯度,例如我的损失函数中的 tf.wheretf .cond 在我的图中。它工作得很好,但我想了解 tensorflow 如何通过这些节点反向传播错误,因为没有公式可以从它们计算梯度。

最佳答案

tf.where 的情况下,您有一个具有三个输入的函数,条件 C,值为真 T 和值false F,一个输出Out。梯度接收一个值并且必须返回三个值。目前,没有为条件计算梯度(这几乎没有意义),因此您只需为 TF 计算梯度。假设输入和输出是向量,假设 C[0]True。然后 Out[0] 来自 T[0],它的梯度应该传播回来。另一方面,F[0] 会被丢弃,所以它的梯度应该为零。如果 Out[1]False,则 F[1] 的梯度应该传播,但 T[1]< 的梯度不会传播。因此,简而言之,对于 T,您应该在 CTrue 的情况下传播给定的梯度,并在 False< 的情况下将其设为零,而 F 则相反。如果你看the implementation of the gradient of tf.where (Select operation) , 它正是这样做的:

@ops.RegisterGradient("Select")
def _SelectGrad(op, grad):
c = op.inputs[0]
x = op.inputs[1]
zeros = array_ops.zeros_like(x)
return (None, array_ops.where(c, grad, zeros), array_ops.where(
c, zeros, grad))

请注意,输入值本身不用于计算,这将由产生这些输入的操作的梯度来完成。对于 tf.condthe code is a bit more complicated ,因为在不同的上下文中使用了相同的操作(Merge),而且tf.cond内部也使用了Switch操作。不过思路是一样的。本质上,Switch 操作用于每个输入,因此被激活的输入(第一个如果条件为 True,否则第二个)获得接收到的梯度,另一个输入获得“关闭”梯度(如 None),并且不会进一步传播回来。

关于python - tensorflow在梯度计算时如何处理不可微节点?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/53208334/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com