
huggingface-transformers - How to extract document embeddings from HuggingFace Longformer

Reposted. Author: 行者123. Updated: 2023-12-04 11:40:17

I want to do something similar to

import torch
from transformers import BertTokenizer, BertModel

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
outputs = model(input_ids)
last_hidden_states = outputs[0]  # The last hidden state is the first element of the output
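(from this thread)
but using Longformer.
The documentation example seems to do something similar, but it is confusing (especially regarding how to set the attention mask: I assume I would want to set it on the [CLS] token, while the example sets global attention on what I think are random positions).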
>>> import torch
>>> from transformers import LongformerModel, LongformerTokenizer

>>> model = LongformerModel.from_pretrained('allenai/longformer-base-4096', return_dict=True)
>>> tokenizer = LongformerTokenizer.from_pretrained('allenai/longformer-base-4096')

>>> SAMPLE_TEXT = ' '.join(['Hello world! '] * 1000) # long input document
>>> input_ids = torch.tensor(tokenizer.encode(SAMPLE_TEXT)).unsqueeze(0) # batch of size 1

>>> # Attention mask values -- 0: no attention, 1: local attention, 2: global attention
>>> attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=input_ids.device) # initialize to local attention
>>> attention_mask[:, [1, 4, 21,]] = 2 # Set global attention based on the task. For example,
... # classification: the <s> token
... # QA: question tokens
... # LM: potentially on the beginning of sentences and paragraphs
>>> outputs = model(input_ids, attention_mask=attention_mask)
>>> sequence_output = outputs.last_hidden_state
>>> pooled_output = outputs.pooler_output
(from here)

Best answer

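You don't need to touch those values (unless you want to tune how Longformer attends to different tokens). In the example listed above, global attention is forced on just the tokens at positions 1, 4, and 21 (zero-based indices). The docs use arbitrary numbers here, but sometimes you may want global attention on a particular kind of token, such as the question tokens in a sequence (e.g. <question tokens> + <answer tokens>, with global attention only on the first part).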
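To make the question-token case concrete, here is a minimal sketch of building such a mask. It assumes a recent transformers release, in which `LongformerModel` accepts a separate `global_attention_mask` argument (1 = global, 0 = local) instead of the value 2 inside `attention_mask` used by the quoted docs example. The helper name `build_global_attention_mask`, the dummy token ids, and the choice of 5 leading positions are made up for illustration.

```python
import torch


def build_global_attention_mask(input_ids: torch.Tensor, num_global_tokens: int = 1) -> torch.Tensor:
    """Hypothetical helper: 1 (global attention) on the first `num_global_tokens`
    positions of every sequence, 0 (local attention) everywhere else."""
    mask = torch.zeros_like(input_ids, dtype=torch.long)
    mask[:, :num_global_tokens] = 1
    return mask


# Dummy batch: one sequence of 12 token ids (placeholders, not real vocabulary ids).
input_ids = torch.arange(12).unsqueeze(0)

# Classification / embedding case: global attention on the leading <s> token only.
cls_mask = build_global_attention_mask(input_ids)

# QA-style case: assume the first 5 positions hold <s> plus the question tokens.
qa_mask = build_global_attention_mask(input_ids, num_global_tokens=5)

# With a real model, the mask would be passed along these lines (assumed call):
# outputs = model(input_ids, attention_mask=attention_mask, global_attention_mask=cls_mask)
```

For plain document embeddings, the single-token `cls_mask` variant is probably the natural starting point, since it mirrors the "classification: the <s> token" hint in the docs comment.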
If you are only looking for embeddings, you can follow what is discussed here: The last layers of longformer for document embeddings.
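The linked discussion is not reproduced here, so the following is only a sketch of two common ways to collapse the per-token `last_hidden_state` into one vector per document: take the first (`<s>`) position, or average over the non-padding tokens. The function name `pool_document_embedding` and the random stand-in tensor are assumptions. With a real model, the input would be `outputs.last_hidden_state`.

```python
import torch


def pool_document_embedding(last_hidden_state: torch.Tensor,
                            attention_mask: torch.Tensor,
                            strategy: str = "mean") -> torch.Tensor:
    """Hypothetical helper.

    last_hidden_state: (batch, seq_len, hidden)
    attention_mask:    (batch, seq_len), 0 marks padding, anything > 0 is a real token
    Returns a (batch, hidden) tensor.
    """
    if strategy == "cls":
        # Hidden state of the first token (<s> for Longformer, [CLS] for BERT).
        return last_hidden_state[:, 0]
    if strategy == "mean":
        # Average only over real tokens so padding does not dilute the embedding.
        mask = (attention_mask > 0).unsqueeze(-1).to(last_hidden_state.dtype)
        summed = (last_hidden_state * mask).sum(dim=1)
        counts = mask.sum(dim=1).clamp(min=1.0)
        return summed / counts
    raise ValueError(f"unknown strategy: {strategy!r}")


# Stand-in for model output: batch of 2 documents, 6 tokens each, hidden size 8.
torch.manual_seed(0)
hidden = torch.randn(2, 6, 8)
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1],
                               [1, 1, 1, 0, 0, 0]])  # second document has 3 padding tokens

cls_embedding = pool_document_embedding(hidden, attention_mask, strategy="cls")
mean_embedding = pool_document_embedding(hidden, attention_mask, strategy="mean")
print(cls_embedding.shape, mean_embedding.shape)
```

Which pooling works better depends on the downstream task. The `pooler_output` shown in the question's snippet is a third option that the model returns directly.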

Regarding huggingface-transformers - How to extract document embeddings from HuggingFace Longformer, we found a similar question on Stack Overflow: https://stackoverflow.com/questions/63708496/
