gpt4 book ai didi

python - 来自 Transformers 的 BertForSequenceClassification 的大小不匹配和多类问题

转载 作者:行者123 更新时间:2023-12-05 03:37:41 24 4
gpt4 key购买 nike

我刚刚在由电子商务网站的产品和标签(部门)组成的数据集上训练了一个 BERT 模型。这是一个多类问题。我使用 BertForSequenceClassification 来预测每个产品的部门。我在训练和评估中拆分它,我使用了 pytorch 的数据加载器,并且我得到了很好的分数,没有过度拟合。

现在我想在一个新的数据集上尝试它,以检查它如何处理看不见的数据。但我无法加载模型并应用于新数据集。我收到以下错误:

RuntimeError: Error(s) in loading state_dict for BertForSequenceClassification:
size mismatch for classifier.weight: copying a param with shape torch.Size([59, 1024]) from checkpoint, the shape in current model is torch.Size([105, 1024]).
size mismatch for classifier.bias: copying a param with shape torch.Size([59]) from checkpoint, the shape in current model is torch.Size([105]).

我发现问题可能是两个数据集之间的标签大小不匹配。我搜索了一下,找到了使用 ignore_mismatched_sizes=True 作为 pretrained 参数的建议。但我一直收到同样的错误。

这是我尝试预测未见数据时的部分代码:

from transformers import BertForSequenceClassification

# Just right before the actual usage select your hardware
device = torch.device('cuda') # or cpu
model = model.to(device) # send your model to your hardware



model = BertForSequenceClassification.from_pretrained("neuralmind/bert-large-portuguese-cased",
num_labels=len(label_dict),
output_attentions=False,
output_hidden_states=False,
ignore_mismatched_sizes=True)

model.to(device)

model.load_state_dict(torch.load('finetuned_BERT_epoch_2_full-Copy1.model', map_location=torch.device('cuda')))

_, predictions, true_vals = evaluate(dataloader_validation)
accuracy_per_class(predictions, true_vals)

有人可以帮我解决吗?我不知道我还能做什么!

任何帮助我都非常感谢!

最佳答案

您的新数据集有 105 个类别,而您的模型已针对 59 个类别进行训练。正如您已经提到的,您可以使用 ignore_mismatched_sizes加载你的模型。此参数将加载模型的嵌入层和编码层,但会随机初始化分类头:

model = BertForSequenceClassification.from_pretrained("finetuned_BERT_epoch_2_full-Copy1.model",
num_labels=105,
output_attentions=False,
output_hidden_states=False,
ignore_mismatched_sizes=True)

如果想保留59个标签的分类层,增加46个标签,可以引用这个answer .另请注意此答案的评论,因为由于新标签的随机初始化,此方法不会提供任何有意义的结果。

关于python - 来自 Transformers 的 BertForSequenceClassification 的大小不匹配和多类问题,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/69194640/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com