gpt4 book ai didi

tensorflow - 如何恢复 LSTM 层

转载 作者:行者123 更新时间:2023-12-04 01:54:45 25 4
gpt4 key购买 nike

如果我能在保存和恢复 LSTM 方面得到一些帮助,我将不胜感激。

我有这个 LSTM 层 -

# LSTM cell
cell = tf.contrib.rnn.LSTMCell(n_hidden)
output, current_state = tf.nn.dynamic_rnn(cell, word_vectors, dtype=tf.float32)

outputs = tf.transpose(output, [1, 0, 2])
last = tf.gather(outputs, int(outputs.get_shape()[0]) - 1)

# Saver function
saver = tf.train.Saver()
saver.save(sess, 'test-model')

saver 保存模型并允许我保存和恢复 LSTM 的权重和偏差。但是,我需要恢复这个 LSTM 层并为其提供一组新的输入。

为了恢复整个模型,我正在做:
with tf.Session() as sess:
saver = tf.train.import_meta_graph('test-model.meta')
saver.restore(sess, tf.train.latest_checkpoint('./'))
  • 我是否可以使用预训练的权重和偏差来初始化 LSTM 单元?
  • 如果没有,我该如何恢复这个 LSTM 层?

  • 非常感谢!

    最佳答案

    您已经在加载模型,以及模型的权重。您需要做的就是使用 get_tensor_by_name从图中获取任何张量并将其用于推理。

    例子:

    with tf.Session() as sess:
    saver = tf.train.import_meta_graph('test-model.meta')
    saver.restore(sess, tf.train.latest_checkpoint('./'))

    # Get the tensors by their variable name
    word_vec = = detection_graph.get_tensor_by_name('word_vec:0')
    output_tensor = detection_graph.get_tensor_by_name('outputs:0')

    sess.run(output_tensor, feed_dict={word_vec: ...})

    在上面的例子中 word_vecoutputs是在创建图形期间分配给张量的名称。确保您指定了名称,以便可以通过名称来称呼它们。

    关于tensorflow - 如何恢复 LSTM 层,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/45154459/

    25 4 0
    Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
    广告合作:1813099741@qq.com 6ren.com