gpt4 book ai didi

python - 在 Tensorflow 中实现嵌入 Dropout

转载 作者:行者123 更新时间:2023-12-04 14:24:31 25 4
gpt4 key购买 nike

我正在阅读关于“Regularizing and Optimizing LSTM Language Models”的这篇论文,他们谈论Embedding Dropout,它说“由于丢失发生在用于完整前向和反向传递的嵌入矩阵上,这意味着特定单词的所有出现都将在该遍中消失,相当于对单热嵌入和嵌入查找之间的连接执行变分丢弃。”但是,我似乎无法在 tensorflow 实验中找到一种很好的方法来做到这一点。对于每个新批处理,我目前使用以下代码嵌入我的序列:

embedding_sequence = tf.contrib.layers.embed_sequence(features['input_sequence'], vocab_size=n_tokens, embed_dim=word_embedding_size)

现在我可以轻松地将 dropout 应用于 embedding_sequence,但是我读到的论文表明应该从整个正向/反向传递中丢弃相同的词。关于仍然允许我使用 embed_sequence 的简单方法的任何建议?这是我认为我的方法应该是分解后的 embed_sequence但我仍然不相信这是正确的...

建议的解决方案

embedding_matrix = tf.get_variable("embeddings", shape=[vocab_size, embed_dim], dtype = tf.float32, initializer = None, trainable=True)
embedding_matrix_dropout = tf.nn.dropout(embedding_matrix, keep_prob=keep_prob)
embedding_sequence = tf.nn.embedding_lookup(embedding_matrix_dropout, features['input_sequence'])

有没有更合适的方法来处理这个问题?是否有任何我从 embed_sequence 中得到但我无法从我提出的解决方案中得到的东西?

我不确定的次要事项:

  1. 我的 embedding_matrix 初始值设定项应该是什么?默认设置为无?
  2. tf.nn.dropout似乎按照论文中提到的 1/keep_prob 处理缩放是必要的,对吗?

最佳答案

你可以像这样使用嵌入dropouts..

with tf.variable_scope('embedding'):
self.embedding_matrix = tf.get_variable( "embedding", shape=[self.vocab_size, self.embd_size], dtype=tf.float32, initializer=self.initializer)

with tf.name_scope("embedding_dropout"):
self.embedding_matrix = tf.nn.dropout(self.embedding_matrix, keep_prob=self.embedding_dropout, noise_shape=[self.vocab_size,1])

with tf.name_scope('input'):
self.input_batch = tf.placeholder(tf.int64, shape=(None, None))
self.inputs = tf.nn.embedding_lookup(self.embedding_matrix, self.input_batch)

这随机将嵌入矩阵的行设置为零,如 https://arxiv.org/pdf/1512.05287.pdf 中所述你提到的论文中引用了它。

来源:

https://github.com/tensorflow/tensorflow/issues/14746

类似的pytorch实现:

https://github.com/salesforce/awd-lstm-lm/blob/master/embed_regularize.py

关于python - 在 Tensorflow 中实现嵌入 Dropout,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/48693587/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com