gpt4 book ai didi

python - embedding在tensorflow中有什么作用

转载 作者:太空狗 更新时间:2023-10-30 01:48:00 26 4
gpt4 key购买 nike

我正在阅读将 RNN 与 tensorflow 结合使用的示例:ptb_word_lm.py

我不知道 embeddingembedding_lookup 在这里做什么。它如何为张量添加另一个维度?从 (20, 25) 到 (20, 25, 200)。在这种情况下,(20,25) 是一个批量大小为 20 的 25 个时间步长。我不明白您如何/为什么可以将单元格的 hidden_​​size 添加为输入数据的维度?通常,输入数据是大小为 [batch_size, num_features] 的矩阵,模型会将 num_features ---> hidden_​​dims 映射到矩阵大小为 [num_features, hidden_​​dims] 产生大小为 [batch-size, hidden-dims] 的输出。那么 hidden_​​dims 如何成为输入张量的维度呢?

input_data, targets = reader.ptb_producer(train_data, 20, 25)
cell = tf.nn.rnn_cell.BasicLSTMCell(200, forget_bias=1.0, state_is_tuple=True)
initial_state = cell.zero_state(20, tf.float32)
embedding = tf.get_variable("embedding", [10000, 200], dtype=tf.float32)
inputs = tf.nn.embedding_lookup(embedding, input_data)

input_data_train # <tf.Tensor 'PTBProducer/Slice:0' shape=(20, 25) dtype=int32>
inputs # <tf.Tensor 'embedding_lookup:0' shape=(20, 25, 200) dtype=float32>

outputs = []
state = initial_state
for time_step in range(25):
if time_step > 0:
tf.get_variable_scope().reuse_variables()

cell_output, state = cell(inputs[:, time_step, :], state)
outputs.append(cell_output)

output = tf.reshape(tf.concat(1, outputs), [-1, 200])

outputs # list of 20: <tf.Tensor 'BasicLSTMCell/mul_2:0' shape=(20, 200) dtype=float32>
output # <tf.Tensor 'Reshape_2:0' shape=(500, 200) dtype=float32>

softmax_w = tf.get_variable("softmax_w", [config.hidden_size, config.vocab_size], dtype=tf.float32)
softmax_b = tf.get_variable("softmax_b", [config.hidden_size, config.vocab_size], dtype=tf.float32)
logits = tf.matmul(output, softmax_w) + softmax_b

loss = tf.nn.seq2seq.sequence_loss_by_example([logits], [tf.reshape(targets, [-1])],[tf.ones([20*25], dtype=tf.float32)])
cost = tf.reduce_sum(loss) / batch_size

最佳答案

好的,我不会尝试解释这段特定的代码,但我会尝试回答“什么是嵌入?”标题的一部分。

基本上它是将原始输入数据映射到一组实值维度,并且组织原始输入数据在这些维度中的“位置”以改进任务。

在 tensorflow 中,如果你想象一些文本输入字段有“king”、“queen”、“girl”、“boy”,并且你有 2 个嵌入维度。希望反向传播能够训练嵌入,将皇室概念放在一个轴上,将性别放在另一个轴上。因此,在这种情况下,4 分类值特征被“简化”为具有 2 维的浮点嵌入特征。

它们是使用查找表实现的,可以从原始表或字典排序中散列。对于一个训练有素的人,你可能会输入“Queen”,然后你会说 [1.0,1.0],输入“Boy”然后你会说 [0.0,0.0]。

Tensorflow 将错误反向传播到这个查找表中,希望从随机初始化的字典开始的东西会逐渐变成我们上面看到的那样。

希望这对您有所帮助。如果没有,请查看:http://colah.github.io/posts/2014-07-NLP-RNNs-Representations/

关于python - embedding在tensorflow中有什么作用,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/40184537/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com