gpt4 book ai didi

python - Tensorflow:如何从 rnn_cell.BasicLSTM 和 rnn_cell.MultiRNNCell 获取所有变量

转载 作者:IT老高 更新时间:2023-10-28 21:10:00 25 4
gpt4 key购买 nike

我有一个设置,我需要在使用 tf.initialize_all_variables() 的主要初始化之后初始化 LSTM。 IE。我想调用 tf.initialize_variables([var_list])

有没有办法为两者收集所有内部可训练变量:

  • rnn_cell.BasicLSTM
  • rnn_cell.MultiRNNCell

以便我可以初始化这些参数?

我想要这个的主要原因是我不想重新初始化之前的一些训练值。

最佳答案

解决问题的最简单方法是使用变量范围。范围内的变量名称将以其名称为前缀。这是一个简短的片段:

cell = rnn_cell.BasicLSTMCell(num_nodes)

with tf.variable_scope("LSTM") as vs:
# Execute the LSTM cell here in any way, for example:
for i in range(num_steps):
output[i], state = cell(input_data[i], state)

# Retrieve just the LSTM variables.
lstm_variables = [v for v in tf.all_variables()
if v.name.startswith(vs.name)]

# [..]
# Initialize the LSTM variables.
tf.initialize_variables(lstm_variables)

它与 MultiRNNCell 的工作方式相同。

编辑:将 tf.trainable_variables 更改为 tf.all_variables()

关于python - Tensorflow:如何从 rnn_cell.BasicLSTM 和 rnn_cell.MultiRNNCell 获取所有变量,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/35013080/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com