
python - Issue loading a pretrained BERT model

Reposted. Author: 行者123. Updated: 2023-12-05 06:50:17

I'm using Huggingface to further train a BERT model. I saved the model in two ways: (1) the whole model, with model.save_pretrained(save_location); and (2) only the model's state_dict, with torch.save(model.state_dict(), 'model.pth'). However, when I try to load the trained BERT model with bert_mask_lm = BertForMaskedLM.from_pretrained('save_location') for method (1) and torch.load('model.pth') for method (2), I get the following error in both cases:

AttributeError                            Traceback (most recent call last)
~/anaconda3/lib/python3.6/site-packages/torch/serialization.py in _check_seekable(f)
307 try:
--> 308 f.seek(f.tell())
309 return True

AttributeError: 'torch._C.PyTorchFileReader' object has no attribute 'seek'

During handling of the above exception, another exception occurred:

The detailed stack trace for method (1) is as follows:

AttributeError                            Traceback (most recent call last)
~/anaconda3/lib/python3.6/site-packages/torch/serialization.py in _check_seekable(f)
307 try:
--> 308 f.seek(f.tell())
309 return True

AttributeError: 'torch._C.PyTorchFileReader' object has no attribute 'seek'

During handling of the above exception, another exception occurred:

AttributeError Traceback (most recent call last)
~/anaconda3/lib/python3.6/site-packages/transformers/modeling_utils.py in from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
1037 try:
-> 1038 state_dict = torch.load(resolved_archive_file, map_location="cpu")
1039 except Exception:

~/anaconda3/lib/python3.6/site-packages/torch/serialization.py in load(f, map_location, pickle_module, **pickle_load_args)
593 return torch.jit.load(opened_file)
--> 594 return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
595 return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)

~/anaconda3/lib/python3.6/site-packages/moxing/framework/file/file_io_patch.py in _load(f, map_location, pickle_module, **pickle_load_args)
199
--> 200 _check_seekable(f)
201 f_should_read_directly = _should_read_directly(f)

~/anaconda3/lib/python3.6/site-packages/torch/serialization.py in _check_seekable(f)
310 except (io.UnsupportedOperation, AttributeError) as e:
--> 311 raise_err_msg(["seek", "tell"], e)
312 return False

~/anaconda3/lib/python3.6/site-packages/torch/serialization.py in raise_err_msg(patterns, e)
303 + " try to load from it instead.")
--> 304 raise type(e)(msg)
305 raise e

AttributeError: 'torch._C.PyTorchFileReader' object has no attribute 'seek'. You can only torch.load from a file that is seekable. Please pre-load the data into a buffer like io.BytesIO and try to load from it instead.

During handling of the above exception, another exception occurred:

OSError Traceback (most recent call last)
~/work/algo-FineTuningBert3/FineTuningBert3.py in <module>()
1 #Model load checking
----> 2 loadded_model = BertForMaskedLM.from_pretrained('/cache/raw_model/')

~/anaconda3/lib/python3.6/site-packages/transformers/modeling_utils.py in from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
1039 except Exception:
1040 raise OSError(
-> 1041 f"Unable to load weights from pytorch checkpoint file for '{pretrained_model_name_or_path}' "
1042 f"at '{resolved_archive_file}'"
1043 "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True. "

OSError: Unable to load weights from pytorch checkpoint file for '/cache/raw_model/' at '/cache/raw_model/pytorch_model.bin'If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True.
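The AttributeError message above suggests pre-loading the data into an io.BytesIO buffer. Below is a minimal sketch of that suggestion. It uses a small hypothetical state_dict written to a temporary file rather than the poster's checkpoint. In this trace, the object rejected by _check_seekable is a torch._C.PyTorchFileReader, not the file that was passed in, and the call runs through moxing/framework/file/file_io_patch.py rather than torch's own _load. So it is not confirmed that a buffer avoids this particular failure.

```python
import io
import os
import tempfile

import torch

# Hypothetical stand-in for a real checkpoint: a tiny state_dict.
state = {"weight": torch.arange(6, dtype=torch.float32).reshape(2, 3)}

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model.pth")
    torch.save(state, path)

    # Read the raw bytes first, then give torch.load a seekable
    # in-memory buffer, as the error message suggests.
    with open(path, "rb") as f:
        buffer = io.BytesIO(f.read())
    loaded = torch.load(buffer, map_location="cpu")
```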

I'm using the latest torch (1.7.1) and transformers (4.3.3) packages. I don't understand what causes this error or how to fix it.
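For reference, here is a sketch of the two save/load round trips described in the question. This is the intended workflow, not a fix for the error above. It assumes a tiny, randomly initialised BertConfig with hypothetical sizes so that it runs offline without downloading bert-base-uncased.

Note that torch.load('model.pth') returns only the state_dict (a dict of tensors), not a model. Method (2) therefore needs a model built from the same config, followed by load_state_dict. Depending on the transformers version, save_pretrained may write model.safetensors instead of pytorch_model.bin, and from_pretrained reads either one.

```python
import os
import tempfile

import torch
from transformers import BertConfig, BertForMaskedLM

# Hypothetical tiny config so the sketch runs offline (no model download).
config = BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=1,
                    num_attention_heads=2, intermediate_size=64)
model = BertForMaskedLM(config)

with tempfile.TemporaryDirectory() as save_location:
    # Method (1): save config + weights, then rebuild the model from the directory.
    model.save_pretrained(save_location)
    reloaded = BertForMaskedLM.from_pretrained(save_location)

    # Method (2): torch.load gives back only the state_dict, so build a model
    # from the same config and copy the saved tensors into it.
    pth_path = os.path.join(save_location, "model.pth")
    torch.save(model.state_dict(), pth_path)
    state_dict = torch.load(pth_path, map_location="cpu")
    model2 = BertForMaskedLM(config)
    model2.load_state_dict(state_dict)
```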

Best answer

I ran into the same thing. It turns out this can be caused by incompatible versions of PyTorch and transformers; the problem appears to be version-specific.

I used the following, without manually downloading the latest bert-base-uncased model:

pip install torch==1.5.1
pip install transformers==3.0.2

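from transformers import BertForTokenClassification

MODEL_NAME = 'bert-base-uncased'
# Downloads the pretrained weights and config for MODEL_NAME on first use
model = BertForTokenClassification.from_pretrained(
    MODEL_NAME
)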

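This automatically downloads the pretrained BERT model that matches the installed transformers version. Note: I separately downloaded vocab.txt from the official website and used it with the BERT tokenizer class.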
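Here is a minimal sketch of building the tokenizer from a local vocab.txt, as the answer describes. The tiny vocabulary and temporary path are hypothetical stand-ins for the file downloaded from the official site. BertTokenizer reads one WordPiece token per line, and each token's id is its line index.

```python
import os
import tempfile

from transformers import BertTokenizer

# Hypothetical miniature vocabulary standing in for the official vocab.txt.
tokens = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", "hello", "world"]

with tempfile.TemporaryDirectory() as tmp:
    vocab_path = os.path.join(tmp, "vocab.txt")
    with open(vocab_path, "w", encoding="utf-8") as f:
        f.write("\n".join(tokens) + "\n")

    # Build the tokenizer from the local vocab file, lowercasing input
    # as an uncased model expects.
    tokenizer = BertTokenizer(vocab_file=vocab_path, do_lower_case=True)
    ids = tokenizer.encode("Hello world")  # adds [CLS] ... [SEP] by default
```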

A similar question, "python - Issue loading a pretrained BERT model", can be found on Stack Overflow: https://stackoverflow.com/questions/66442648/
