gpt4 book ai didi

machine-learning - 如何使 RNN 单元的权重在 Tensorflow 中无法训练?

转载 作者:行者123 更新时间:2023-11-30 08:26:52 25 4
gpt4 key购买 nike

我正在尝试制作一个 Tensorflow 图,其中图的一部分已经过预训练并在预测模式下运行,而其余部分则进行训练。我已经像这样定义了我的预训练单元:

rnn_cell = tf.contrib.rnn.BasicLSTMCell(100)

state0 = tf.Variable(pretrained_state0,trainable=False)
state1 = tf.Variable(pretrained_state1,trainable=False)
pretrained_state = [state0, state1]

outputs, states = tf.contrib.rnn.static_rnn(rnn_cell,
data_input,
dtype=tf.float32,
initial_state = pretrained_state)

将初始变量设置为 trainable=False 没有帮助。这些仅用于初始化权重,因此权重仍然会发生变化。

我仍然需要在训练步骤中运行优化器,因为模型的其余部分需要训练。但如何防止优化器更改此 rnn 单元中的权重?

是否存在与 trainable=False 等效的 rnn_cell?

最佳答案

您可以使用tf.stop_gradient()来防止图表的预训练部分更新其权重,也可以使用optimiser() 您可以在其中指定应训练图表的哪些部分。第二种方法涉及:

 #Create variable scope for the trainable parts of the graph: tf.variable_scope('train').

# get trainable variables
t_vars = tf.trainable_variables()
train_vars = [var for var in t_vars if var.name.startswith('train')]
# train only the variables of a particular scope
opt = optimizer.minimize(cost, var_list=train_vars)

关于machine-learning - 如何使 RNN 单元的权重在 Tensorflow 中无法训练?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/44959435/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com