gpt4 book ai didi

python - tensorflow 基本word2vec示例: Shouldn't we be using weights [nce_weight Transpose] for the representation and not embedding matrix?

转载 作者:太空宇宙 更新时间:2023-11-03 15:17:01 25 4
gpt4 key购买 nike

我指的是this sample code
在下面的代码片段中:

embeddings = tf.Variable(tf.random_uniform([vocabulary_size, embedding_size], -1.0, 1.0))
embed = tf.nn.embedding_lookup(embeddings, train_inputs)

# Construct the variables for the NCE loss
nce_weights = tf.Variable(tf.truncated_normal([vocabulary_size, embedding_size],stddev=1.0 / math.sqrt(embedding_size)))
nce_biases = tf.Variable(tf.zeros([vocabulary_size]))

loss = tf.reduce_mean(
tf.nn.nce_loss(weights=nce_weights,
biases=nce_biases,
labels=train_labels,
inputs=embed,
num_sampled=num_sampled,
num_classes=vocabulary_size))

optimizer = tf.train.GradientDescentOptimizer(1.0).minimize(loss)

现在NCE_Loss函数只不过是一个在optput层带有softmax的单隐藏层神经网络[知道只需要几个负样本]

图的这一部分只会更新网络的权重,它不会对“嵌入”矩阵/张量做任何事情。

因此,理想情况下,一旦网络经过训练,我们必须首先再次将其通过 embeddings_matrix 一次,然后乘以“nce_weights”的转置[将其视为在输入层和输出层的相同权重自动编码器]以达到每个单词的隐藏层表示,我们称之为 word2vec (?)

但是如果查看代码的后面部分,embeddings 矩阵的值将用于单词表示。 This

甚至是tensorflow doc for NCE loss ,提到输入(我们向其传递 embed,它使用 embeddings)作为第一层输入激活值。

inputs: A Tensor of shape [batch_size, dim]. The forward activations of the input network.

正常的反向传播在网络的第一层停止,NCE 损失的这种实现是否超出了损失范围并将损失传播到输入值(从而传播到嵌入)?

这似乎是一个额外的步骤? Refer this对于为什么我称之为额外的步骤,他有同样的解释。

最佳答案

我想要弄清楚阅读和浏览 tensorflow 是这样的

虽然整个东西是单隐藏层神经网络,但实际上是一个自动编码器。但权重并没有绑定(bind),这是我假设的。

编码器由权重矩阵embeddings组成,解码器由nce_weights组成。现在embed只不过是隐藏层输出,通过将输入与embeddings相乘得到。

因此,embeddingsnce_weights 都将在图中更新。我们可以选择两个权重矩阵中的任何一个,这里更优选embeddings

编辑1:

实际上,对于tf.nn.nce_losstf.nn.sampled_softmax_loss来说,参数、权重和偏差都是针对输入Weights(tranpose) X +偏差,目标函数,可以是逻辑回归/softmax函数 [refer] .

但是反向传播/梯度下降会一直发生到您正在构建的图的最底部,并且不仅仅停留在函数的权重和偏差上。因此,tf.nn.nce_losstf.nn.sampled_softmax_loss 中的 input 参数也会更新,而该参数又是 的构建>嵌入矩阵。

关于python - tensorflow 基本word2vec示例: Shouldn't we be using weights [nce_weight Transpose] for the representation and not embedding matrix?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/43804684/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com