gpt4 book ai didi

deep-learning - Theano - 如何覆盖部分操作图的梯度

转载 作者:行者123 更新时间:2023-12-04 05:56:09 25 4
gpt4 key购买 nike

我手头有一个相当复杂的模型。该模型有多个具有线性结构的部分:

y = theano.tensor.dot(W,x) + b

我想构建一个优化器,它使用自定义规则来计算所有线性 结构的梯度,同时保持其他操作不变。 为我的模型的所有线性部分覆盖梯度操作的最简单方法是什么?最好不需要编写新的操作。

最佳答案

所以,我花了一些时间研究 PR (截至 2017 年 1 月 13 日未合并 已合并)对于 Theano,它使用户能够部分覆盖 theano.OpFromGraph 实例的梯度。覆盖是通过符号图完成的,因此您仍然可以获得 theano 优化的全部好处。

典型用例:

  • 数字安全考虑
  • 重新缩放/剪裁渐变
  • 像黎曼自然梯度这样的特殊梯度程序

要制作具有覆盖梯度的 Op:

  1. 制作所需的计算图
  2. 为你的 Op 的梯度创建一个 OpFromGraph 实例(或一个 python 函数)
  3. 让 OfG 实例成为您的 Op,并设置 grad_overrides 参数
  4. 调用 OfG 实例来构建您的模型

定义一个 OpFromGraph 就像编译一个 theano 函数,但有一些不同:

  • 不支持 updatesgivens(截至 2017 年 1 月)
  • 你得到一个符号运算而不是一个数值函数

例子:

'''
This creates an atan2_safe Op with smoothed gradient at (0,0)
'''
import theano as th
import theano.tensor as T

# Turn this on if you want theano to build one large graph for your model instead of precompiling the small graph.
USE_INLINE = False
# In a real case you would set EPS to a much smaller value
EPS = 0.01

# define a graph for needed Op
s_x, s_y = T.scalars('xy')
s_darg = T.scalar(); # backpropagated gradient
s_arg = T.arctan2(s_y, s_x)
s_abs2 = T.sqr(s_x) + T.sqr(s_y) + EPS
s_dx = -s_y / s_abs2
s_dy = s_x / s_abs2

# construct OfG with gradient overrides
# NOTE: there are unused inputs in the gradient expression,
# however the input count must match, so we pass
# on_unused_input='ignore'
atan2_safe_grad = th.OpFromGraph([s_x, s_y, s_darg], [s_dx, s_dy], inline=USE_INLINE, on_unused_input='ignore')
atan2_safe = th.OpFromGraph([s_x, s_y], [s_arg], inline=USE_INLINE, grad_overrides=atan2_safe_grad)

# build graph using the new Op
x, y = T.scalar(), T.scalar()
arg = atan2_safe(x, y)
dx, dy = T.grad(arg, [x, y])
fn = th.function([x, y], [dx, dy])
fn(1., 0.) # gives [-0.0, 0.99099]
fn(0., 0.) # gives [0.0, 0.0], no more annoying nan!

注意:theano.OpFromGraph 在很大程度上仍处于试验阶段,可能会出现错误。

关于deep-learning - Theano - 如何覆盖部分操作图的梯度,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/40613225/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com