gpt4 book ai didi

python - 只训练tensorflow中的一些变量

转载 作者:太空宇宙 更新时间:2023-11-03 14:59:56 24 4
gpt4 key购买 nike

我正在使用 tensorflow 进行梯度体面分类。

train_op = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

这里的 cost 是我在优化中使用的成本函数。在 Session 中启动 Graph 后,Graph 可以作为:

sess.run(train_op, feed_dict)

这样,成本函数中的所有变量都将更新,以最小化成本。

这是我的问题。训练时如何只更新成本函数中的一些变量..?有没有办法将创建的变量转换为常量或其他东西......?

最佳答案

有几个很好的答案,这个主题应该已经关闭了: stackoverflow Quora

只是为了避免人们再次点击这里:

tensorflow 优化器的最小化函数为此目的采用了一个 var_list 参数:

first_train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
"scope/prefix/for/first/vars")
first_train_op = optimizer.minimize(cost, var_list=first_train_vars)

second_train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
"scope/prefix/for/second/vars")
second_train_op = optimizer.minimize(cost, var_list=second_train_vars)

我照原样从mrry

要获取您应该使用的名称列表而不是 "scope/prefix/for/second/vars",您可以使用:

tf.get_default_graph().get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES)

关于python - 只训练tensorflow中的一些变量,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/38989046/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com