[Feature] Support using FlashAttention2 on Ascend NPU #36696
Conversation
We validated this PR with LLaMA-Factory on both Ascend NPU and GPU. We chose the Llama-3-8b-sft-mixture model and the default examples/train_lora/llama3_lora_sft.yaml with a new property enabling FlashAttention2. After the whole training finished, the loss values on NPU and GPU were close to each other, as expected.
Does this PR support rhymes-ai/Aria?
@SHYuanBest After reading the code in rhymes-ai/Aria, I found that the class `AriaPretrainedModel` inherits from `PreTrainedModel` in transformers. In that case, it can use FlashAttention2 on Ascend NPU with this PR. If you can try it, please tell me the result; looking forward to your reply :)
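For reference, loading such a model with the NPU flash-attention path could look roughly like the sketch below. The auto class and the `trust_remote_code` flag are assumptions, not something confirmed in this thread; check the rhymes-ai/Aria model card for the recommended loading code.

```python
import torch
from transformers import AutoModelForCausalLM

# Sketch: any model whose classes derive from PreTrainedModel can request the
# "flash_attention_2" implementation once the NPU support from this PR is in place.
# The auto class and trust_remote_code flag are assumptions; check the
# rhymes-ai/Aria model card for the recommended loading code.
model = AutoModelForCausalLM.from_pretrained(
    "rhymes-ai/Aria",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map="npu:0",        # requires torch_npu and a visible Ascend NPU
    trust_remote_code=True,
)
```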
Please help review this PR, thanks :) cc @SunMarc @ArthurZucker @Rocketknight1
I think we can determine the flash attention implementation in the modeling_flash_attn file according to the hardware type, instead of introducing excessive modifications to the model files.
Let's avoid updating each and every modeling file:
is_flash_attn_2_available = is_flash_attn_2_available and is_torch_npu_available
is_flash_attn_greater_or_equal_2_10 = is_flash_attn_greater_or_equal_2_10 or is_npu_fa2_top_left_aligned_causal_mask
if you see what I mean!
The function names are a bit ambiguous. Also, remember to read the contributing guide for code formatting: https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md
@ArthurZucker Thanks for your suggestions! We have refreshed the commits and introduced `is_flash_attn_available()` and `flash_attn_supports_top_left_mask()`. The reason why we do not change the logic of the existing functions directly is covered in the Main Modification section of the PR description below.
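To make the intent concrete, a modeling file then only needs a backend-agnostic check, roughly like the following (a sketch, not the PR's exact code):

```python
# Sketch (not the PR's exact code): a modeling file queries the shared helper
# instead of branching on GPU flash-attn vs. Ascend NPU kernels. The import path
# follows the PR description (src/transformers/modeling_flash_attention_utils.py).
from transformers.modeling_flash_attention_utils import flash_attn_supports_top_left_mask


class MyModelFlashAttention2:
    def __init__(self):
        # True when the kernel expects a top-left aligned causal mask
        # (older flash-attn releases, and the NPU kernels per this thread).
        self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
```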
@ArthurZucker Thanks for your suggestion. I have renamed the file. Besides, the CI error does not seem related to my PR, so the code is ready for review and merge :)
Can you just run the code quality command from the contributing guide?
@ArthurZucker I have run the command.
Thanks! 🤗
@FightingZhen Great to see FlashAttention2 supported on Ascend, but I'm wondering how to use this feature. When I run the code, it throws an error.
cc @FightingZhen I don't have NPUs so I can't test anything, happy if you can answer!
BTW we can do a patch for the commits that broke it in 4.55!
Thank you for your attention, we truly appreciate your suggestion. However, to maintain consistency with the community's release plan, we do not recommend providing an additional patch specifically for version 4.55. Additionally, while we currently cannot integrate Ascend NPU test cases into the transformers CI pipeline, we are actively working on building an independent test pipeline internally. This pipeline will periodically verify the functionality of FlashAttention2 on Ascend NPU; once an error is found, we will fix it as quickly as we can.
I see, it works fine, so I'm not sure whether the flash attention implementation is working properly.
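One quick way to sanity-check which attention backend transformers actually selected is to inspect the implementation recorded on the model config. A minimal sketch, reusing the model path and device from the demo script later in this thread (both assumptions here):

```python
from transformers import AutoModelForCausalLM

# Sketch: load with the flash-attention implementation requested and print the
# backend transformers actually resolved for this model.
model = AutoModelForCausalLM.from_pretrained(
    "/models/Qwen3-4B",
    torch_dtype="auto",
    attn_implementation="flash_attention_2",
    device_map="npu:0",
)
print(model.config._attn_implementation)  # expected: "flash_attention_2"
```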
The patch is out btw: https://github.com/huggingface/transformers/releases/tag/v4.55.1
@FightingZhen Sorry about breaking NPU support :/ The only things that should've changed are the import and the unpad/pad functions (which are taken from the fa3 implementations). Can you tell me which part(s) broke? Sadly, we cannot check with NPU hardware ourselves.
I have committed PR #40151 to fix this problem; you can find the root cause in that PR if you are interested.
@FightingZhen Can you provide demo code? I tried model inference with the following script:

from transformers import AutoModelForCausalLM, AutoTokenizer
from time import time

model_name = "/models/Qwen3-4B"
attn_implementation = "flash_attention_2"
# attn_implementation = "sdpa"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    attn_implementation=attn_implementation,
    device_map="npu:0",
)

prompt = "Give me an introduction to large language models."
messages = [
    {"role": "user", "content": prompt},
]
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

start_time = time()
# conduct text completion
generated_ids = model.generate(
    **model_inputs,
    max_new_tokens=16384,
    do_sample=False,
)
end_time = time()

output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()
content = tokenizer.decode(output_ids, skip_special_tokens=True)

print("content_len:", len(content))
print("attention implementation:", attn_implementation)
print("time cost:", end_time - start_time)

and got a similar time cost to the "sdpa" implementation.
I usually use LLaMA-Factory to test this feature; the detailed steps are as follows:
What does this PR do?
The package `flash-attn` is not supported on Ascend NPU and cannot even be installed there. In that situation, we cannot use FlashAttention2 with transformers natively 😞

To solve this problem, we found from the Ascend community (Ascend FlashAttentionScore document) that `torch_npu` provides the FlashAttention2 API `torch_npu.npu_fusion_attention` on Ascend NPU, with solutions for both flash-attention and flash-varlen-attention that can play the same role as the corresponding APIs in the `flash-attn` package. Therefore, this PR adds support for using FlashAttention2 on Ascend NPU.
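For readers unfamiliar with the Ascend API, a call to that kernel looks roughly like the snippet below. This is only an illustration; the argument names, supported layouts and the returned tuple should be double-checked against the Ascend FlashAttentionScore documentation.

```python
import torch
import torch_npu  # Ascend PyTorch adapter

# Illustration only: q/k/v in (batch, seq_len, num_heads, head_dim) layout ("BSND").
q = torch.randn(2, 1024, 8, 128, dtype=torch.float16, device="npu")
k = torch.randn_like(q)
v = torch.randn_like(q)

# Positional arguments are query, key, value, head_num, input_layout; the op
# returns a tuple whose first element is the attention output. In real use the
# softmax scale and causal mask must also be passed; see the Ascend docs for
# the exact keyword names.
out = torch_npu.npu_fusion_attention(q, k, v, 8, "BSND")[0]
print(out.shape)  # torch.Size([2, 1024, 8, 128])
```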
Main Modification
- Introduce `src/transformers/integrations/npu_flash_attention.py` to organize the functions needed for using FlashAttention2 on Ascend NPU. Because the package `flash-attn` is unavailable on Ascend NPU, part of the code is copied from flash-attn/bert_padding.py.
- When `torch_npu` is available on Ascend NPU, patch the six functions `index_first_axis`, `pad_input`, `unpad_input`, `flash_attn_func`, `flash_attn_varlen_func` and `apply_rotary_emb` in `src/transformers/modeling_flash_attention_utils.py` at runtime, so that they replace the corresponding functions from the `flash-attn` package.
- Introduce two new functions, `is_flash_attn_available()` and `flash_attn_supports_top_left_mask()`, in `src/transformers/modeling_flash_attention_utils.py`. The former determines whether FlashAttention2 can be used; the latter determines whether to use a top-left aligned causal mask or a bottom-right aligned one. With this design, any new platform or backend only needs to append its FlashAttention2 and top-left-mask logic to these two functions, which seems more intuitive than modifying every modeling file (see the sketch after this list).
- Update the modeling files in `src/transformers/models` to use the newly introduced `is_flash_attn_available()` and `flash_attn_supports_top_left_mask()` functions.

Fixes: #36618
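A rough sketch of how these two helpers could behave, based only on the description above (not the merged code) and on the assumption that the NPU kernels use a top-left aligned causal mask:

```python
# Sketch only: the merged code lives in src/transformers/modeling_flash_attention_utils.py
# and may differ in detail.
from transformers.utils import (
    is_flash_attn_2_available,           # flash-attn package usable (GPU path)
    is_flash_attn_greater_or_equal_2_10,
    is_torch_npu_available,              # torch_npu present (Ascend NPU path)
)


def is_flash_attn_available() -> bool:
    # FlashAttention2 can be used either through flash-attn or through
    # torch_npu's fused attention kernels.
    return is_flash_attn_2_available() or is_torch_npu_available()


def flash_attn_supports_top_left_mask() -> bool:
    # flash-attn >= 2.1 uses a bottom-right aligned causal mask; older releases
    # and (assumed) the NPU kernels use a top-left aligned one.
    if is_torch_npu_available():
        return True
    return not is_flash_attn_greater_or_equal_2_10()
```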