gpt4 book ai didi

tensorflow - 如何在 HuggingFace Transformers 库中获得预训练 BERT 模型的中间层输出?

转载 作者:行者123 更新时间:2023-12-04 00:57:10 25 4
gpt4 key购买 nike

(我正在关注关于 BERT 词嵌入的 this pytorch 教程,在教程中,作者访问了 BERT 模型的中间层。)

我想要的是使用 HuggingFace 的 Transformers 库访问 TensorFlow2 中 BERT 模型的单个输入 token 的最后 4 层。因为每一层输出一个长度为768的向量,所以最后4层的形状为4*768=3072 (对于每个 token )。

如何在 TF/keras/TF2 中实现这一点,以获得输入 token 的预训练模型的中间层? (稍后我将尝试获取句子中每个标记的标记,但现在一个标记就足够了)。

我正在使用 HuggingFace 的 BERT 模型:

!pip install transformers
from transformers import (TFBertModel, BertTokenizer)

bert_model = TFBertModel.from_pretrained("bert-base-uncased") # Automatically loads the config
bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
sentence_marked = "hello"
tokenized_text = bert_tokenizer.tokenize(sentence_marked)
indexed_tokens = bert_tokenizer.convert_tokens_to_ids(tokenized_text)

print (indexed_tokens)
>> prints [7592]

输出是一个 token ( [7592] ),它应该是 BERT 模型的输入。

最佳答案

BERT 模型输出的第三个元素是一个元组,它由嵌入层的输出以及中间层隐藏状态组成。来自 documentation :

hidden_states (tuple(tf.Tensor), optional, returned when config.output_hidden_states=True): tuple of tf.Tensor (one for the output of the embeddings + one for the output of each layer) of shape (batch_size, sequence_length, hidden_size).

Hidden-states of the model at the output of each layer plus the initial embedding outputs.



对于 bert-base-uncased模型, config.output_hidden_states默认为 True .因此,要访问 12 个中间层的隐藏状态,您可以执行以下操作:
outputs = bert_model(input_ids, attention_mask)
hidden_states = outputs[2][1:]
hidden_states中有12个元素从头到尾所有层对应的元组,每一层都是一个形状为 (batch_size, sequence_length, hidden_size)的数组.因此,例如,要访问批处理中所有样本的第五个 token 的第三层隐藏状态,您可以执行以下操作: hidden_states[2][:,4] .

请注意,如果您正在加载的模型默认不返回隐藏状态,那么您可以使用 BertConfig 加载配置。上课并通过 output_hidden_state=True论证,像这样:
config = BertConfig.from_pretrained("name_or_path_of_model",
output_hidden_states=True)

bert_model = TFBertModel.from_pretrained("name_or_path_of_model",
config=config)

关于tensorflow - 如何在 HuggingFace Transformers 库中获得预训练 BERT 模型的中间层输出?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/61465103/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com