gpt4 book ai didi

python - word2vec的tensorflow实现

转载 作者:太空狗 更新时间:2023-10-29 17:34:29 25 4
gpt4 key购买 nike

Tensorflow 教程 here指的是它们的基本实现,您可以在 github here 上找到,其中 Tensorflow 作者使用 Skipgram 模型实现 word2vec 向量嵌入训练/评估。

我的问题是关于 generate_batch() 函数中(目标、上下文)对的实际生成。

关于 this line Tensorflow 作者在单词滑动窗口中从“中心”单词索引中随机抽取附近的目标索引。

然而,他们also keep a data structure targets_to_avoid他们首先向其中添加“中心”上下文词(当然我们不想对其进行采样),但在我们添加它们之后还会添加其他词。

我的问题如下:

  1. 为什么要围绕这个词从这个滑动窗口中采样,为什么不只是有一个循环并使用它们而不是采样?他们担心 word2vec_basic.py(他们的“基本”实现)中的性能/内存似乎很奇怪。
  2. 无论 1) 的答案是什么,为什么他们采样并跟踪他们使用targets_to_avoid 选择的内容?如果他们想要真正的随机,他们会使用带替换的选择,如果他们想确保获得所有选项,他们应该使用循环并首先获得所有选项!
  3. 是否内置 tf.models.embedding.gen_word2vec也这样工作?如果是这样,我在哪里可以找到源代码? (在 Github 仓库中找不到 .py 文件)

谢谢!

最佳答案

我尝试了您提出的生成批处理的方法 - 有一个循环并使用整个跳过窗口。结果是:

<强>1。更快地生成批处理

对于 128 的批量大小和 5 的跳过窗口

  • 通过逐个遍历数据生成批处理每 10,000 个批处理花费 0.73s
  • 使用教程代码和 num_skips=2 生成批处理每 10,000 个批处理需要 3.59s

<强>2。过度拟合的风险更高

保持教程代码的其余部分不变,我用两种方式训练模型并记录每 2000 步的平均损失:

enter image description here

这种模式反复出现。它表明每个单词使用 10 个样本而不是 2 个样本会导致过度拟合。

这是我用来生成批处理的代码。它取代了教程的 generate_batch 函数。

data_index = 0

def generate_batch(batch_size, skip_window):
global data_index
batch = np.ndarray(shape=(batch_size), dtype=np.int32) # Row
labels = np.ndarray(shape=(batch_size, 1), dtype=np.int32) # Column

# For each word in the data, add the context to the batch and the word to the labels
batch_index = 0
while batch_index < batch_size:
context = data[get_context_indices(data_index, skip_window)]

# Add the context to the remaining batch space
remaining_space = min(batch_size - batch_index, len(context))
batch[batch_index:batch_index + remaining_space] = context[0:remaining_space]
labels[batch_index:batch_index + remaining_space] = data[data_index]

# Update the data_index and the batch_index
batch_index += remaining_space
data_index = (data_index + 1) % len(data)

return batch, labels

编辑:get_context_indices 是一个简单的函数,它返回 skip_window 中 data_index 周围的索引切片。查看slice() documentation了解更多信息。

关于python - word2vec的tensorflow实现,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/38111129/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com