gpt4 book ai didi

python - 属性错误 : 'GPT2Model' object has no attribute 'gradient_checkpointing'

转载 作者:行者123 更新时间:2023-12-02 22:45:23 31 4
gpt4 key购买 nike

我最初尝试在 Flask 中加载 GPT2 微调模型。在初始化函数期间使用以下方法加载模型:

app.modelgpt2 = torch.load('models/model_gpt2.pt', map_location=torch.device('cpu'))
app.modelgpt2tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

但是在按照下面的代码片段执行预测任务时:

from flask import current_app
input_ids = current_app.modelgpt2tokenizer.encode("sample sentence here", return_tensors='pt')
sample_outputs = current_app.modelgpt2.generate(input_ids,
do_sample=True,
top_k=50,
min_length=30,
max_length=300,
top_p=0.95,
temperature=0.7,
num_return_sequences=1)

它抛出问题中提到的以下错误:AttributeError: 'GPT2Model' 对象没有属性 'gradient_checkpointing'

The error trace is listed starting from the model.generate function:File "/venv/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 28, in decorate_contextreturn func(*args, **kwargs)

File "/venv/lib/python3.8/site-packages/transformers/generation_utils.py", line 1017, in generatereturn self.sample(

File "/venv/lib/python3.8/site-packages/transformers/generation_utils.py", line 1531, in sampleoutputs = self(

File "/venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_implreturn forward_call(*input, **kwargs)

File "/venv/lib/python3.8/site-packages/transformers/models/gpt2/modeling_gpt2.py", line 1044, in forwardtransformer_outputs = self.transformer(

File "/venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_implreturn forward_call(*input, **kwargs)

File "/venv/lib/python3.8/site-packages/transformers/models/gpt2/modeling_gpt2.py", line 861, in forwardprint(self.gradient_checkpointing)

File "/venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1177, in getattrraise AttributeError("'{}' object has no attribute '{}'".format(

AttributeError: 'GPT2Model' object has no attribute 'gradient_checkpointing'

通过modeling_gpt2.py检查,默认情况下self.gradient_checkpointing在类的构造函数中设置为False

最佳答案

仅当框架使用 venv 或部署框架(如 uWSGI 或 gunicorn)运行时,才会发现此问题。当使用 transformers 版本 4.10.0 而不是最新包时,此问题得到解决。

关于python - 属性错误 : 'GPT2Model' object has no attribute 'gradient_checkpointing' ,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/69773687/

31 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com