gpt4 book ai didi

python - tf.get_collection 提取一个作用域的变量

转载 作者:太空狗 更新时间:2023-10-30 02:55:36 24 4
gpt4 key购买 nike

我有 n(例如:n=3)个作用域和 x(例如:x=4)没有在每个作用域中定义的变量。范围是:

model/generator_0
model/generator_1
model/generator_2

一旦我计算了损失,我想在运行时根据一个标准从一个范围中提取并提供所有变量。因此,我选择的范围 idx 的索引是一个转换为 int32 的 argmin 张量

<tf.Tensor 'model/Cast:0' shape=() dtype=int32>

我已经试过了:

train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'model/generator_'+tf.cast(idx, tf.string)) 

这显然行不通。有什么方法可以使用 idx 将属于该特定范围的所有 x 变量传递给优化器。

提前致谢!

维尼什·斯里尼瓦桑

最佳答案

你可以在 TF 1.0 rc1 或更高版本中做这样的事情:

v = tf.Variable(tf.ones(()))
loss = tf.identity(v)
with tf.variable_scope('adamoptim') as vs:
optim = tf.train.AdamOptimizer(learning_rate=0.1).minimize(loss)
optim_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=vs.name)
print([v.name for v in optim_vars]) #=> prints lists of vars created

关于python - tf.get_collection 提取一个作用域的变量,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/42073239/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com