gpt4 book ai didi

nlp - 为什么我们在 HuggingFace BART 的生成过程中需要一个 decoder_start_token_id?

转载 作者:行者123 更新时间:2023-12-04 17:23:29 48 4
gpt4 key购买 nike

在 HuggingFace 代码的生成阶段: https://github.com/huggingface/transformers/blob/master/src/transformers/generation_utils.py#L88-L100

他们传入一个 decoder_start_token_id ,我不确定他们为什么需要这个。在 BART 配置中,decoder_start_token_id实际上是 2 ( https://huggingface.co/facebook/bart-base/blob/main/config.json ),这是句子标记 </s> 的结尾.

我尝试了一个简单的例子:

from transformers import *

import torch
model = BartForConditionalGeneration.from_pretrained('facebook/bart-base')
tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
input_ids = torch.LongTensor([[0, 894, 213, 7, 334, 479, 2]])
res = model.generate(input_ids, num_beams=1, max_length=100)

print(res)

preds = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True).strip() for g in res]
print(preds)

我得到的结果:

tensor([[  2,   0, 894, 213,   7, 334, 479,   2]])
['He go to school.']

虽然不影响最终的“tokenization decoding”结果。但对我来说,我们生成的第一个标记实际上是 2 似乎很奇怪。 (</s>)。

最佳答案

您可以在代码中看到 encoder-decoder models解码器的输入标记从原始标记右移(参见函数 shift_tokens_right)。这意味着要猜测的第一个标记始终是 BOS(句子开头)。您可以检查示例中是否属于这种情况。

为了让解码器理解这一点,我们必须选择一个始终跟在 BOS 后面的第一个标记,那么它可能是哪个?老板?显然不是,因为它后面必须跟着常规标记。填充 token ?也不是一个好的选择,因为它后面跟着另一个填充标记或 EOS(句子结尾)。那么,EOS 呢?好吧,这是有道理的,因为它从来没有跟在训练集中的任何东西后面,所以没有下一个标记会发生冲突。此外,句子的开头跟在另一个句子的结尾不是很自然吗?

关于nlp - 为什么我们在 HuggingFace BART 的生成过程中需要一个 decoder_start_token_id?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/64904840/

48 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com