gpt4 book ai didi

tensorflow - 加载预训练的 word2vec 以初始化 Estimator model_fn 中的 embedding_lookup

转载 作者:行者123 更新时间:2023-12-04 17:31:18 25 4
gpt4 key购买 nike

我正在解决一个文本分类问题。我使用 Estimator 定义了我的分类器我自己的类(class)model_fn .我想用谷歌预训练的word2vec嵌入为初始值,然后针对手头的任务进一步优化它。

我看到了这个帖子:Using a pre-trained word embedding (word2vec or Glove) in TensorFlow
它解释了如何在“原始”TensorFlow 代码中进行处理。但是,我真的很想使用 Estimator类(class)。

作为扩展,我想在 Cloud ML Engine 上训练这段代码,有没有一种很好的方法可以传入具有初始值的相当大的文件?

假设我们有类似的东西:

def build_model_fn():
def _model_fn(features, labels, mode, params):
input_layer = features['feat'] #shape=[-1, params["sequence_length"]]
#... what goes here to initialize W

embedded = tf.nn.embedding_lookup(W, input_layer)
...
return predictions

estimator = tf.contrib.learn.Estimator(
model_fn=build_model_fn(),
model_dir=MODEL_DIR,
params=params)
estimator.fit(input_fn=read_data, max_steps=2500)

最佳答案

嵌入通常足够大,唯一可行的方法是使用它们来初始化 tf.Variable在你的图表中。这将允许您利用分布式等的参数服务器。

对于这个(以及其他),我建议你使用新的“核心”估计器, tf.estimator.Estimator 因为这会让事情变得更容易。

根据您提供的链接中的答案,并且知道我们想要一个变量而不是常量,我们可以采取以下方法:

(2) 使用 feed dict 初始化变量,或者
(3) 从检查点加载变量

我将首先介绍选项(3),因为它更容易,更好:

在您的 model_fn ,只需使用 Tensor 初始化一个变量由 tf.contrib.framework.load_variable 返回称呼。这需要:

  • 您的嵌入有一个有效的 TF 检查点
  • 您知道检查点内嵌入变量的完全限定名称。

  • 代码非常简单:
    def model_fn(mode, features, labels, hparams):
    embeddings = tf.Variable(tf.contrib.framework.load_variable(
    'gs://my-bucket/word2vec_checkpoints/',
    'a/fully/qualified/scope/embeddings'
    ))
    ....
    return tf.estimator.EstimatorSpec(...)

    但是,如果您的嵌入不是由另一个 TF 模型生成的,则此方法对您不起作用,因此选项 (2)。

    对于 (2),我们需要使用 tf.train.Scaffold 它本质上是一个配置对象,包含启动 tf.Session 的所有选项。 (由于很多原因,估计器故意隐藏)。

    您可以指定 Scaffold tf.train.EstimatorSpec 你返回你的 model_fn .

    我们在 model_fn 中创建一个占位符,并将其设为
    我们的嵌入变量的初始化操作,然后传递 init_feed_dict通过 Scaffold .例如
    def model_fn(mode, features, labels, hparams):
    embed_ph = tf.placeholder(
    shape=[hparams.vocab_size, hparams.embedding_size],
    dtype=tf.float32)
    embeddings = tf.Variable(embed_ph)
    # Define your model
    return tf.estimator.EstimatorSpec(
    ..., # normal EstimatorSpec args
    scaffold=tf.train.Scaffold(init_feed_dict={embed_ph: my_embedding_numpy_array})
    )

    这里发生的是 init_feed_dict将填充 embed_ph 的值运行时占位符,这将允许 embeddings.initialization_op (占位符的分配),运行。

    关于tensorflow - 加载预训练的 word2vec 以初始化 Estimator model_fn 中的 embedding_lookup,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/44680769/

    25 4 0
    Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
    广告合作:1813099741@qq.com 6ren.com