gpt4 book ai didi

python - 如何加载部分预训练的 pytorch 模型?

转载 作者:行者123 更新时间:2023-12-03 19:14:31 25 4
gpt4 key购买 nike

我正在尝试让 pytorch 模型在句子分类任务上运行。由于我正在处理医疗笔记,我正在使用 ClinicalBert ( https://github.com/kexinhuang12345/clinicalBERT ) 并希望使用其预先训练的权重。不幸的是,ClinicalBert 模型只将文本分类为 1 个二进制标签,而我有 281 个二进制标签。因此,我试图实现此代码 https://github.com/kaushaltrivedi/bert-toxic-comments-multilabel/blob/master/toxic-bert-multilabel-classification.ipynb其中 bert 之后的最终分类器长度为 281。

如何在不加载分类权重的情况下从 ClinicalBert 模型加载预训练的 Bert 权重?

天真地尝试从预训练的 ClinicalBert 权重中加载权重,我收到以下错误:

size mismatch for classifier.weight: copying a param with shape torch.Size([2, 768]) from checkpoint, the shape in current model is torch.Size([281, 768]).
size mismatch for classifier.bias: copying a param with shape torch.Size([2]) from checkpoint, the shape in current model is torch.Size([281]).

我目前尝试从 pytorch_pretrained_bert 包中替换 from_pretrained 函数,并像这样弹出分类器权重和偏差:
def from_pretrained(cls, pretrained_model_name, state_dict=None, cache_dir=None, *inputs, **kwargs):
...
if state_dict is None:
weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
state_dict = torch.load(weights_path, map_location='cpu')
state_dict.pop('classifier.weight')
state_dict.pop('classifier.bias')
old_keys = []
new_keys = []
...

我收到以下错误消息:
信息-modeling_diagnosis-BertForMultiLabelSequenceClassification 的权重未从预训练模型初始化:['classifier.weight', 'classifier.bias']

最后,我想从临床伯特预训练权重中加载 bert 嵌入,并随机初始化顶级分类器权重。

最佳答案

在加载之前删除状态字典中的键是一个好的开始。假设您正在使用 nn.Module.load_state_dict 要加载预训练的权重,您还需要设置 strict=False参数以避免因意外或丢失的键而导致错误。这将忽略 state_dict 中不存在于模型中的条目(意外键),更重要的是,将使用默认初始化(缺少键)保留缺失的条目。为了安全起见,您可以检查该方法的返回值,以验证有问题的权重是缺失键的一部分,并且没有任何意外的键。

关于python - 如何加载部分预训练的 pytorch 模型?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/61211685/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com