gpt4 book ai didi

pytorch - 无法从 'flash_attn_func' 导入名称 'flash_attn'

转载 作者:行者123 更新时间:2023-12-03 07:50:26 33 4
gpt4 key购买 nike

尝试加载 llama2 模型:

model = AutoModelForCausalLM.from_pretrained(
model_name,
quantization_config=bnb_config,
device_map=device_map
)

使用这些 bnb_config:

BitsAndBytesConfig {
"bnb_4bit_compute_dtype": "bfloat16",
"bnb_4bit_quant_type": "nf4",
"bnb_4bit_use_double_quant": true,
"llm_int8_enable_fp32_cpu_offload": false,
"llm_int8_has_fp16_weight": false,
"llm_int8_skip_modules": null,
"llm_int8_threshold": 6.0,
"load_in_4bit": true,
"load_in_8bit": false,
"quant_method": "bitsandbytes"
}

我收到此错误:

RuntimeError: Failed to import transformers.models.llama.modeling_llama because of the following error (look up to see its traceback):
cannot import name 'flash_attn_func' from 'flash_attn' (/opt/conda/lib/python3.10/site-packages/flash_attn/__init__.py)

任何帮助都会有帮助。

最佳答案

我在微调 llama2 模型时遇到了同样的错误,解决方案是恢复到以前版本的 Transformer。

pip install transformers==4.33.1 --upgrade

这应该有效。

关于pytorch - 无法从 'flash_attn_func' 导入名称 'flash_attn',我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/77283770/

33 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com