gpt4 book ai didi

tensorflow - tf.gradients,我如何理解 `grad_ys` 并使用它?

转载 作者:行者123 更新时间:2023-12-05 03:05:27 26 4
gpt4 key购买 nike

tf.gradients中,有一个关键字参数grad_ys

grad_ys is a list of tensors of the same length as ys that holds the initial gradients for each y in ys. When grad_ys is None, we fill in a tensor of ‘1’s of the shape of y for each y in ys. A user can provide their own initial grad_ys to compute the derivatives using a different initial gradient for each y (e.g., if one wanted to weight the gradient differently for each value in each y).

为什么这里需要grads_ys?这里的文档是隐含的。能否请您给出一些具体的用途和代码?

我的 tf.gradients 示例代码是

In [1]: import numpy as np

In [2]: import tensorflow as tf

In [3]: sess = tf.InteractiveSession()

In [4]: X = tf.placeholder("float", shape=[2, 1])

In [5]: Y = tf.placeholder("float", shape=[2, 1])

In [6]: W = tf.Variable(np.random.randn(), name='weight')

In [7]: b = tf.Variable(np.random.randn(), name='bias')

In [8]: pred = tf.add(tf.multiply(X, W), b)

In [9]: cost = 0.5 * tf.reduce_sum(tf.pow(pred-Y, 2))

In [10]: grads = tf.gradients(cost, [W, b])

In [11]: sess.run(tf.global_variables_initializer())

In [15]: W_, b_, pred_, cost_, grads_ = sess.run([W, b, pred, cost, grads],
feed_dict={X: [[2.0], [3.]], Y: [[3.0], [2.]]})

最佳答案

grad_ys 仅在高级用例中需要。以下是您可以考虑的方式。

tf.gradients 允许您计算 tf.gradients(y, x, grad_ys) = grad_ys * dy/dx。换句话说,grad_ys 是每个 y 的乘数。在这种表示法中,提供这个论点似乎很愚蠢,因为一个人应该能够自己倍增,即 tf.gradients(y, x, grad_ys) = grad_ys * tf.gradients(y, x) .不幸的是,这种等式并不成立,因为当向后计算梯度时,我们在每一步之后执行归约(通常是求和)以获得“中间损失”。

此功能在许多情况下都很有用。文档字符串中提到了一个。这是另一个。记住链式法则 - dz/dx = dz/dy * dy/dx。假设我们想要计算 dz/dxdz/dy 不可微,我们只能对其进行近似计算。假设我们以某种方式计算近似值并将其称为 approx。然后,dz/dx = tf.gradients(y, x, grad_ys=approx)

另一个用例是当您有一个具有“巨大扇入”的模型时。假设您有 100 个输入源,它们经过几层(称为“100 个分支”),在 y 处合并,然后再经过 10 个层,直到出现 loss。一次计算整个模型的所有梯度(这需要记住许多激活)可能不适合内存。一种方法是先计算 d(loss)/dy。然后,使用 tf.gradients(y, branch_i_variables, grad_ys=d(loss)/dy) branch_i 中变量相对于 loss 的梯度。使用它(以及我正在跳过的更多细节),您可以减少峰值内存需求。

关于tensorflow - tf.gradients,我如何理解 `grad_ys` 并使用它?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/50967885/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com