gpt4 book ai didi

tensorflow - 如何在 Tensorflow 中使用 stop_gradient

转载 作者:行者123 更新时间:2023-12-03 07:01:38 30 4
gpt4 key购买 nike

我想知道如何在tensorflow中使用stop_gradient,并且文档对我来说不清楚。

我目前正在使用 stop_gradient 来生成损失函数的梯度。 CBOW word2vec 模型中的词嵌入。我只想获取值,而不是进行反向传播(因为我正在生成对抗性示例)。

目前,我正在使用代码:

lossGrad = gradients.gradients(loss, embed)[0]
real_grad = lossGrad.eval(feed_dict)

但是当我运行这个时,它无论如何都会进行反向传播!我做错了什么,同样重要的是,我该如何解决这个问题?

澄清:为了澄清“反向传播”,我的意思是“计算值并更新模型参数”。

更新

如果我在第一个训练步骤后运行上面的两行代码,那么在 100 个训练步骤后我会得到与不运行这两行代码时不同的损失。我可能从根本上误解了有关 Tensorflow 的某些内容。

我尝试在图形声明的开头和每个训练步骤之前使用 set_random_seed 进行设置。多次运行之间的总损失是一致的,但包含/排除这两条线之间的总损失却不一致。因此,如果不是 RNG 导致了差异,也不是训练步骤之间模型参数的意外更新,您知道什么会导致这种行为吗?

解决方案

好吧,有点晚了,但我是这样解决的。我只想优化一些而不是全部变量。我认为防止优化某些变量的方法是使用 stop_grad - 但我从未找到一种方法来实现这一点。也许有办法,但对我有用的是调整我的优化器以仅优化变量列表。所以代替:

opt = tf.train.GradientDescentOptimizer(learning_rate=eta)
train_op = opt.minimize(loss)

我用过:

opt = tf.train.GradientDescentOptimizer(learning_rate=eta)
train_op = opt.minimize(loss, var_list=[variables to optimize over])

这阻止了 opt 更新不在 var_list 中的变量。希望它也适合您!

最佳答案

tf.stop_gradient提供了一种在反向传播期间不计算某些变量的梯度的方法。

例如,在下面的代码中,我们有三个变量,w1 , w2 , w3并输入x 。损失为square((x1.dot(w1) - x.dot(w2 * w3))) 。我们希望将这种损失最小化为 w1但想保留w2w3固定的。为了实现这一点,我们只需输入 tf.stop_gradient(tf.matmul(x, w2*w3)) .

在下图中,我绘制了如何 w1 , w2 ,和w3从它们的初始值作为训练迭代的函数。可见w2w3当 w1 改变时保持固定,直到它等于 w2 * w3 .

一张图片显示 w1 只学习但不学习 w2w3 :

An image showing that w1 only learns but not w2 and w3

import tensorflow as tf
import numpy as np

w1 = tf.get_variable("w1", shape=[5, 1], initializer=tf.truncated_normal_initializer())
w2 = tf.get_variable("w2", shape=[5, 1], initializer=tf.truncated_normal_initializer())
w3 = tf.get_variable("w3", shape=[5, 1], initializer=tf.truncated_normal_initializer())
x = tf.placeholder(tf.float32, shape=[None, 5], name="x")


a1 = tf.matmul(x, w1)
a2 = tf.matmul(x, w2*w3)
a2 = tf.stop_gradient(a2)
loss = tf.reduce_mean(tf.square(a1 - a2))
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.1)
gradients = optimizer.compute_gradients(loss)
train_op = optimizer.apply_gradients(gradients)

关于tensorflow - 如何在 Tensorflow 中使用 stop_gradient,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/33727935/

30 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com