gpt4 book ai didi

python - 在 tensorflow 中的急切执行训练期间修复变量的一部分

转载 作者:行者123 更新时间:2023-12-01 07:21:20 25 4
gpt4 key购买 nike

有没有办法在急切执行更新步骤期间仅更新一些变量?考虑这个最小的工作示例:

import tensorflow as tf
tf.enable_eager_execution()
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)

x = tf.Variable([1.0, 2.0])

def train(x):
with tf.GradientTape() as tape:
loss = x[0]**2 + x[1]**2 + 1/(x[0]+x[1])
variables = [x]
grads = tape.gradient(loss, variables)
optimizer.apply_gradients(zip(grads, variables))

for _ in range(2000):
train(x)
print(x.numpy())

收敛于[0.5, 0.5]。我想将 x[0] 的值修复为其初始值,同时保持其他所有内容不变。到目前为止我尝试过的:

  • 在训练步骤中添加 x[0].assign(1.0) 操作,这会不必要地增大图表
  • 更改 variables = [x[:-1]] 会给出 ValueError: Nogradingsprovided foranyvariable: ['tf.Tensor([1.], shape=(1 ,), dtype=float32)']
  • 添加grads = [grads[0][1:]],这会给出tensorflow.python.framework.errors_impl.InvalidArgumentError:var和delta没有相同的形状[2] [1] [操作:ResourceApplyGradientDescent]
  • 同时执行这两项操作,会出现 TypeError: 'NoneType' object is not subscriptable

对于这个 MWE,我可以轻松地使用两个单独的变量,但我对一般情况感兴趣,在这种情况下我只想更新数组的已知切片。

最佳答案

您可以将不想更新的索引的梯度设置为0。在下面的代码片段中,mask张量表示我们想要更新哪些元素(值1 ),以及我们不想更新哪些元素(值0)。

import tensorflow as tf
tf.enable_eager_execution()

optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)

x = tf.Variable([1.0, 2.0])
mask = tf.constant([0.0, 1.0])

def train(x):
with tf.GradientTape() as tape:
loss = x[0]**2 + x[1]**2 + 1/(x[0]+x[1])
variables = [x]

grads = tape.gradient(loss, variables) * mask
optimizer.apply_gradients(zip(grads, variables))

for _ in range(100):
train(x)
print(x.numpy())

解决您的问题的另一个可能的解决方案是停止 x[0] 所依赖的操作的梯度。例如:

optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)

x = tf.Variable([1.0, 2.0])

def train(x):
with tf.GradientTape() as tape:
loss = tf.stop_gradient(x[0])**2 + x[1]**2 + 1/(tf.stop_gradient(x[0])+x[1])
variables = [x]

grads = tape.gradient(loss, variables)
optimizer.apply_gradients(zip(grads, variables))

for _ in range(100):
train(x)
print(x.numpy())

关于python - 在 tensorflow 中的急切执行训练期间修复变量的一部分,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/57680916/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com