gpt4 book ai didi

python - 想要通过占位符选择运行时训练的变量

转载 作者:太空宇宙 更新时间:2023-11-03 20:16:10 24 4
gpt4 key购买 nike

我正在寻找一种方法来根据我的纪元 ID 选择可训练变量以在运行时进行更新。我知道我可以将变量范围下的一组变量命名为 tf.variable_scope 。

如果我将占位符创建为 train_vars = tf.placeholder(shape = [None], dtype = type(tf.GraphKeys)),则会引发以下错误:

TypeError: Expected DataType for argument 'dtype' not <class 'type'>.

通过占位符传递此可训练列表的正确方法是什么,还是没有办法?

最佳答案

一种方法是使用 tf.stop_gradienttf.cond选择梯度是否应该反向传播到变量,例如这样:

v = tf.Variable(...)
v_is_trainable = tf.placeholder((), tf.bool) # Or tf.placeholder_with_default
v_value = tf.cond(v_is_trainable, lambda: v, lambda: tf.stop_gradient(v))

然后只需将 v_is_trainable: Truev_is_trainable: False 添加到 feed_dict 即可使该变量可训练或不可训练。

编辑:如果您想通过指示可训练变量的集合来选择变量,您可以这样做

trainable_vars = tf.placeholder([None], tf.string)
v = tf.Variable(...)
# The variable is trainable if its name is in trainable_vars
v_value = tf.cond(tf.reduce_any(tf.equal(v.name, trainable_vars)),
lambda: v, lambda: tf.stop_gradient(v))

然后,如果您在名为“MY_VARS”的集合中拥有要训练的所有变量,则可以在 feed_dict 中给出:

trainable_vars: [v.name for v in tf.get_collection('MY_VARS')]

关于python - 想要通过占位符选择运行时训练的变量,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/58431623/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com