gpt4 book ai didi

variables - 默认情况下是否可以训练局部变量?

转载 作者:行者123 更新时间:2023-12-05 05:12:14 25 4
gpt4 key购买 nike

当我浏览指南时 https://www.tensorflow.org/guide/variables ,我对下面的描述感到困惑(粗体):

By default every tf.Variable gets placed in the following two collections:

  • tf.GraphKeys.GLOBAL_VARIABLES --- variables that can be shared across multiple devices,
  • tf.GraphKeys.TRAINABLE_VARIABLES --- variables for which TensorFlow will calculate gradients.

If you don't want a variable to be trainable, add it to the tf.GraphKeys.LOCAL_VARIABLES collection instead. For example, the following snippet demonstrates how to add a variable named my_local to this collection:

my_local = tf.get_variable("my_local", shape=(), collections [tf.GraphKeys.LOCAL_VARIABLES])`

Alternatively, you can specify trainable=False as an argument to tf.get_variable:

my_non_trainable = tf.get_variable("my_non_trainable", shape=(), trainable=False)

但是当我创建一个局部变量时,它会自动添加到集合 tf.GraphKeys.TRAINABLE_VARIABLES 中,这意味着它是可训练的。那么,局部变量是否可训练?

最佳答案

文档确实令人困惑。默认情况下,局部变量也会添加到可训练变量的集合中。您可以通过检查 tf.trainable_variables() 来检查这一点。因此,看起来要使局部变量不可可训练,将其添加到LOCAL_VARIABLES 集合中是不够的,但您需要关键字trainable=False

这是一个简短的脚本,显示局部变量和全局变量都在训练循环中更新:

import tensorflow as tf

my_local = tf.get_variable("my_local", shape=(), collections=[tf.GraphKeys.LOCAL_VARIABLES],
initializer=tf.constant_initializer(1.0))
my_global = tf.get_variable("my_global", shape=(),
initializer=tf.constant_initializer(2.0))

target_value = tf.constant(4.0)
loss = tf.abs(my_local + my_global - target_value)
optim = tf.train.AdamOptimizer(learning_rate=1.0).minimize(loss)

for v in tf.trainable_variables():
print(v.name)

with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())
print("local init: ", sess.run(my_local))
print("global init: ", sess.run(my_global))
for i in range(2):
_, l = sess.run([optim, loss])
print("loss {:.4f}".format(l))
print("local: ", sess.run(my_local))
print("global: ", sess.run(my_global))

打印

my_local:0
my_global:0
local init: 1.0
global init: 2.0
loss 1.0000
local: 1.9999996
global: 2.9999995
loss 1.0000
local: 1.9473683
global: 2.9473681

如果您在对 tf.get_variable 的调用中设置 trainable=False,则 my_local 的值不会改变。

关于variables - 默认情况下是否可以训练局部变量?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/54780410/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com