gpt4 book ai didi

python - Keras Word2Vec 实现

转载 作者:太空狗 更新时间:2023-10-30 01:11:23 25 4
gpt4 key购买 nike

我正在使用 http://adventuresinmachinelearning.com/word2vec-keras-tutorial/ 中的实现学习一些关于 word2Vec 的知识。我不明白的是为什么损失函数没有减少?

Iteration 119200, loss=0.7305528521537781
Iteration 119300, loss=0.6254740953445435
Iteration 119400, loss=0.8255964517593384
Iteration 119500, loss=0.7267132997512817
Iteration 119600, loss=0.7213149666786194
Iteration 119700, loss=0.6156617999076843
Iteration 119800, loss=0.11473365128040314
Iteration 119900, loss=0.6617216467857361

根据我的理解,网络是此任务中使用的标准网络:

input_target = Input((1,))
input_context = Input((1,))

embedding = Embedding(vocab_size, vector_dim, input_length=1, name='embedding')

target = embedding(input_target)
target = Reshape((vector_dim, 1))(target)
context = embedding(input_context)
context = Reshape((vector_dim, 1))(context)

dot_product = Dot(axes=1)([target, context])
dot_product = Reshape((1,))(dot_product)
output = Dense(1, activation='sigmoid')(dot_product)

model = Model(inputs=[input_target, input_context], outputs=output)
model.compile(loss='binary_crossentropy', optimizer='rmsprop') #adam??

单词来自 http://mattmahoney.net/dc/text8.zip 中大小为 10000 的词汇表(英文文本)

我注意到有些词在某种程度上是及时学习的,比如数字和文章的上下文很容易被猜到,但损失从一开始就停留在 0.7 左右,随着迭代的进行,它只是随机波动。

training部分是这样做的(没有标准的fit方法感觉很奇怪)

arr_1 = np.zeros((1,))
arr_2 = np.zeros((1,))
arr_3 = np.zeros((1,))
for cnt in range(epochs):
idx = np.random.randint(0, len(labels)-1)
arr_1[0,] = word_target[idx]
arr_2[0,] = word_context[idx]
arr_3[0,] = labels[idx]
loss = model.train_on_batch([arr_1, arr_2], arr_3)
if cnt % 100 == 0:
print("Iteration {}, loss={}".format(cnt, loss))

我是否遗漏了关于此类网络的一些重要信息?没写的完全按照上面的链接实现

最佳答案

我遵循了相同的教程,算法再次通过样本后损失下降了。请注意,损失函数仅针对当前目标和上下文词对计算。在本教程的代码示例中,一个时期只有一个样本,因此您需要的目标和上下文词的数量要多才能达到损失下降的程度。

我用下面这行实现了训练部分

model.fit([word_target, word_context], labels, epochs=5)

请注意,这可能需要很长时间,具体取决于语料库的大小。 train_on_batch功能使您可以更好地控制训练,您可以改变批量大小或选择您在训练的每个步骤中选择的样本。

关于python - Keras Word2Vec 实现,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/51039800/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com