gpt4 book ai didi

python - Tensorflow:在另一个不同的模型中使用在一个模型中训练的权重

转载 作者:太空狗 更新时间:2023-10-30 01:27:25 31 4
gpt4 key购买 nike

我正在尝试使用小批量在 Tensorflow 中训练 LSTM,但训练完成后我想通过一次提交一个示例来使用该模型。我可以在 Tensorflow 中设置图表来训练我的 LSTM 网络,但之后我无法按我想要的方式使用训练后的结果。

设置代码看起来像这样:

#Build the LSTM model.
cellRaw = rnn_cell.BasicLSTMCell(LAYER_SIZE)
cellRaw = rnn_cell.MultiRNNCell([cellRaw] * NUM_LAYERS)

cell = rnn_cell.DropoutWrapper(cellRaw, output_keep_prob = 0.25)

input_data = tf.placeholder(dtype=tf.float32, shape=[SEQ_LENGTH, None, 3])
target_data = tf.placeholder(dtype=tf.float32, shape=[SEQ_LENGTH, None])
initial_state = cell.zero_state(batch_size=BATCH_SIZE, dtype=tf.float32)

with tf.variable_scope('rnnlm'):
output_w = tf.get_variable("output_w", [LAYER_SIZE, 6])
output_b = tf.get_variable("output_b", [6])

outputs, final_state = seq2seq.rnn_decoder(input_list, initial_state, cell, loop_function=None, scope='rnnlm')
output = tf.reshape(tf.concat(1, outputs), [-1, LAYER_SIZE])
output = tf.nn.xw_plus_b(output, output_w, output_b)

...注意两个占位符,input_data 和 target_data。我没有费心包括优化器设置。训练完成并结束训练类(class)后,我想设置一个新 session ,使用经过训练的 LSTM 网络,其输入由完全不同的占位符提供,例如:

with tf.Session() as sess:
with tf.variable_scope("simulation", reuse=None):
cellSim = cellRaw
input_data_sim = tf.placeholder(dtype=tf.float32, shape=[1, 1, 3])
initial_state_sim = cell.zero_state(batch_size=1, dtype=tf.float32)
input_list_sim = tf.unpack(input_data_sim)

outputsSim, final_state_sim = seq2seq.rnn_decoder(input_list_sim, initial_state_sim, cellSim, loop_function=None, scope='rnnlm')
outputSim = tf.reshape(tf.concat(1, outputsSim), [-1, LAYER_SIZE])

with tf.variable_scope('rnnlm'):
output_w = tf.get_variable("output_w", [LAYER_SIZE, nOut])
output_b = tf.get_variable("output_b", [nOut])

outputSim = tf.nn.xw_plus_b(outputSim, output_w, output_b)

第二部分返回以下错误:

tensorflow.python.framework.errors.InvalidArgumentError: You must feed a value for placeholder tensor 'Placeholder' with dtype float
[[Node: Placeholder = Placeholder[dtype=DT_FLOAT, shape=[], _device="/job:localhost/replica:0/task:0/cpu:0"]()]]

...大概是因为我正在使用的图仍然有旧的训练占位符附加到经过训练的 LSTM 节点。 “提取”经过训练的 LSTM 并将其放入具有不同输入样式的新的不同图形中的正确方法是什么? Tensorflow 的变量作用域功能似乎可以解决类似的问题,但示例 in the documentation所有人都在谈论使用变量范围作为管理变量名称的一种方式,以便同一段代码将在同一图中生成相似的子图。 “重用”功能似乎接近我想要的功能,但我发现上面链接的 Tensorflow 文档根本不清楚它的作用。不能给单元格本身命名(换句话说,

cellRaw = rnn_cell.MultiRNNCell([cellRaw] * NUM_LAYERS, name="multicell")

无效),虽然我可以给 seq2seq.rnn_decoder() 命名,但如果我原封不动地使用该节点,我大概无法删除 rnn_cell.DropoutWrapper()。

问题:

将经过训练的 LSTM 权重从一个图移动到另一个图的正确方法是什么?

开始新 session “释放资源”但不会删除内存中构建的图形是否正确?

在我看来,“重用”功能允许 Tensorflow 在当前变量范围之外搜索具有相同名称(存在于不同范围内)的变量,并在当前范围内使用它们。这样对吗?如果是,那么链接到该变量的非当前范围的所有图边会发生什么情况?如果不是,如果您尝试在两个不同的范围内使用相同的变量名称,为什么 Tensorflow 会抛出错误?在两个不同的范围内定义两个具有相同名称的变量似乎是完全合理的,例如conv1/sum1 和 conv2/sum1。

在我的代码中,我在一个新的范围内工作,但如果没有将数据从初始默认范围输入到占位符,该图将无法运行。由于某种原因,默认范围是否始终在“范围内”?

如果图形边可以跨越不同的范围,并且不同范围内的名称不能共享,除非它们引用完全相同的节点,那么这似乎首先违背了拥有不同范围的目的。我在这里误解了什么?

谢谢!

最佳答案

将经过训练的 LSTM 权重从一个图移动到另一个图的正确方法是什么?

您可以先创建解码图(使用保存器对象来保存参数),然后创建一个 GraphDef 对象,您可以将其导入更大的训练图中:

basegraph = tf.Graph()
with basegraph.as_default():
***your graph***

traingraph = tf.Graph()
with traingraph.as_default():
tf.import_graph_def(basegraph.as_graph_def())
***your training graph***

确保在为新图表启动 session 时加载变量。

我没有使用此功能的经验,因此您可能需要多研究一下

开始一个新 session “释放资源”但不会删除内存中构建的图形是否正确?

是的,图形对象仍然持有它

在我看来,“重用”功能允许 Tensorflow 在当前变量范围之外搜索具有相同名称(存在于不同范围内)的变量,并在当前范围内使用它们。这样对吗?如果是,那么链接到该变量的非当前范围的所有图边会发生什么情况?如果不是,如果您尝试在两个不同的范围内使用相同的变量名称,为什么 Tensorflow 会抛出错误?在两个不同的范围内定义两个具有相同名称的变量似乎是完全合理的,例如conv1/sum1 和 conv2/sum1。

不,重用是确定在现有名称上使用 get_variable 时的行为,当为真时将返回现有变量,否则将返回一个新变量。通常 tensorflow 不应该抛出错误。您确定您使用的是 tf.get_variable 而不仅仅是 tf.Variable 吗?

在我的代码中,我在一个新的范围内工作,但如果没有将数据从初始默认范围输入到占位符,图表将无法运行。由于某种原因,默认范围是否始终“在范围内”?

我不太明白你的意思。不必总是使用。如果运行操作不需要占位符,则无需定义它。

如果图形边可以跨越不同的范围,并且不同范围内的名称不能共享,除非它们引用完全相同的节点,那么这似乎首先违背了拥有不同范围的目的。我在这里误解了什么?

我认为你对作用域的理解或使用有缺陷,见上文

关于python - Tensorflow:在另一个不同的模型中使用在一个模型中训练的权重,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/39068703/

31 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com