
Conversation

Contributor

@FightingZhen FightingZhen commented Mar 13, 2025

What does this PR do?

The package flash-attn is not supported on Ascend NPU and cannot even be installed there. As a result, we cannot use FlashAttention2 natively with transformers on that platform 😞

To solve this problem, we found in the Ascend community's Ascend FlashAttentionScore Document that Ascend torch_npu provides a FlashAttention2 API, torch_npu.npu_fusion_attention, on Ascend NPU, along with solutions for supporting flash-attention and flash-varlen-attention. These can play the same role as the corresponding APIs in the package flash-attn.

Therefore, this PR adds support for using FlashAttention2 on Ascend NPU.

Main Modification

  1. Create a new file src/transformers/integrations/npu_flash_attention.py to organize the functions needed for using FlashAttention2 on Ascend NPU. Because the package flash-attn is unavailable on Ascend NPU, part of the code is copied from flash-attn/bert_padding.py.
  2. When the package torch_npu is detected on Ascend NPU, patch the six functions index_first_axis, pad_input, unpad_input, flash_attn_func, flash_attn_varlen_func and apply_rotary_emb in src/transformers/modeling_flash_attention_utils.py at runtime to replace the corresponding functions from the package flash-attn.
  3. Introduce two new functions, is_flash_attn_available() and flash_attn_supports_top_left_mask(), in src/transformers/modeling_flash_attention_utils.py. The former determines whether FlashAttention2 can be used; the latter determines whether to use a top-left aligned causal mask or a bottom-right aligned one. With this design, any new platform or backend only needs to extend these two functions, which seems more intuitive than modifying every modeling file (a minimal sketch of this dispatch is shown after this list).
  4. Refresh the FlashAttention2 and top-left-mask related code in src/transformers/models to use the newly introduced is_flash_attn_available() and flash_attn_supports_top_left_mask() functions.
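
For illustration, here is a minimal sketch of the dispatch idea behind points 2 and 3. It is not the exact code from this PR: the two helper names follow the description above, the imported utilities already exist in transformers.utils, and the NPU mask branch is a deliberate simplification.

# Minimal sketch (not the PR's exact code) of how the two new helpers can
# dispatch between the flash-attn package (GPU) and torch_npu (Ascend NPU).
from transformers.utils import (
    is_flash_attn_2_available,
    is_flash_attn_greater_or_equal_2_10,
    is_torch_npu_available,
)


def is_flash_attn_available() -> bool:
    # FlashAttention2 is usable either via the flash-attn package or via the
    # fused attention kernels shipped with torch_npu.
    return is_flash_attn_2_available() or is_torch_npu_available()


def flash_attn_supports_top_left_mask() -> bool:
    # Assumed simplification: treat the NPU fused attention as using a
    # top-left aligned causal mask; the real PR may gate this on a flag
    # exposed by integrations/npu_flash_attention.py.
    if is_torch_npu_available():
        return True
    # flash-attn >= 2.1 uses a bottom-right aligned causal mask.
    return not is_flash_attn_greater_or_equal_2_10()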

Fixes: #36618

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

@github-actions github-actions bot marked this pull request as draft March 13, 2025 11:56
Contributor

Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. When it is ready for review, please click the Ready for review button (at the bottom of the PR page).

@FightingZhen FightingZhen marked this pull request as ready for review March 13, 2025 12:55
@FightingZhen FightingZhen force-pushed the support-ascend-fa branch 8 times, most recently from a550c44 to e7c35be Compare March 14, 2025 14:10
@FightingZhen FightingZhen changed the title [Feature] Support using FlashAttention2 and FlashVarlenAttention on Ascend NPU [Feature] Support using FlashAttention2 on Ascend NPU Mar 14, 2025
@FightingZhen
Contributor Author

We validated this PR with LLaMA-Factory on both Ascend NPU and GPU. We chose the Llama-3-8b-sft-mixture model and the default examples/train_lora/llama3_lora_sft.yaml, with a new property flash_attn=fa2 specified to activate FlashAttention2.

After training finished, the loss values were as follows; they are close, as expected:

  • GPU loss: 0.8756905022789451
  • NPU loss: 0.8757302094908321

@FightingZhen
Contributor Author

@SHYuanBest

Does this PR support rhymes-ai/Aria?

@FightingZhen
Contributor Author

Does this PR support rhymes-ai/Aria?

@SHYuanBest From reading the code in rhymes-ai/Aria, I see that the class AriaPretrainedModel inherits from the class PreTrainedModel in transformers, so it should be able to use FlashAttention2 on Ascend NPU with this PR. If you can try it, please tell me the result; looking forward to your reply :)

@FightingZhen
Contributor Author

Please help review this PR, thanks :) cc @SunMarc @ArthurZucker @Rocketknight1

Contributor

@hiyouga hiyouga left a comment


I think we can determine the flash attention implementation in the modeling_flash_attn file according to the hardware type, instead of introducing excessive modifications to the modeling files.

Collaborator

@ArthurZucker ArthurZucker left a comment


Let's avoid updating each and every modeling:
is_flash_attn_2_available = is_flash_attn_2_available and is_torch_npu_available
is_flash_attn_greater_or_equal_2_10 = is_flash_attn_greater_or_equal_2_10 or is_npu_fa2_top_left_aligned_causal_mask

if you see what I mean!

Contributor

@hiyouga hiyouga left a comment


The function names are a bit ambiguous. Also, remember to read the contributing guide for code formatting: https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md

@FightingZhen FightingZhen force-pushed the support-ascend-fa branch 3 times, most recently from 38058cb to ab988b2 Compare March 20, 2025 12:56
@FightingZhen
Contributor Author

FightingZhen commented Mar 20, 2025

Let's avoid updating each and every modeling:
is_flash_attn_2_available = is_flash_attn_2_available and is_torch_npu_available
is_flash_attn_greater_or_equal_2_10 = is_flash_attn_greater_or_equal_2_10 or is_npu_fa2_top_left_aligned_causal_mask

if you see what I mean!

@ArthurZucker Thanks for your suggestions! We have refreshed the commit and introduced two new functions, is_flash_attn_available() and is_flash_attn_uses_top_left_mask(), in src/transformers/modeling_flash_attention_utils.py. The former determines whether FlashAttention2 can be used; the latter determines whether to use a top-left aligned causal mask or a bottom-right aligned one. With this design, any new platform or backend only needs to extend these two functions, which seems more intuitive than modifying every modeling file.

The reason we did not change the logic of the existing functions is_flash_attn_2_available and is_flash_attn_greater_or_equal_2_10 is that other repositories use those functions to determine whether the package flash-attn is installed with a version greater than 2.1.0. However, this package cannot be installed on Ascend NPU 😞, so adding is_torch_npu_available() to those functions may trigger unexpected errors, as sketched below.
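
To make the concern concrete, here is a hypothetical downstream check of the kind described above. It is illustrative rather than taken from any specific repository; the version lookup mirrors the pattern that appears in the traceback later in this thread.

import importlib.metadata

from packaging import version
from transformers.utils import is_flash_attn_2_available

# Hypothetical downstream pattern: if is_flash_attn_2_available() started
# returning True on Ascend NPU, this version lookup would raise
# importlib.metadata.PackageNotFoundError, because flash_attn cannot be
# installed there.
if is_flash_attn_2_available():
    flash_attn_version = version.parse(importlib.metadata.version("flash_attn"))
    assert flash_attn_version >= version.parse("2.1.0")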

@FightingZhen
Contributor Author

Thanks! One nit about the npu utils

@ArthurZucker Thanks for your suggestion. I have renamed the file from npu_flash_attention_utils to npu_flash_attention and moved it into the integrations directory.

Besides, the CI error does not seem related to my PR, so the code is ready for review and merge :)

@ArthurZucker
Collaborator

Can you just run make fix-copies? 🤗 I can then merge!

@FightingZhen
Contributor Author

Can you just run make fix-copies? 🤗 I can then merge!

@ArthurZucker I have run make fix-copies and fixed the qwen3 and qwen3_moe modular conversion mismatch. However, there are still 2 errors in CI, which still seem unrelated to my PR. Please help me confirm whether this PR can be merged :)

@ArthurZucker ArthurZucker merged commit e686fed into huggingface:main Mar 31, 2025
15 of 18 checks passed
@ArthurZucker
Collaborator

Thanks! 🤗

dmdaksh pushed a commit to dmdaksh/transformers that referenced this pull request Apr 2, 2025

* [Feature] Support using flash-attention on Ascend NPU

* Fix qwen3 and qwen3_moe moduler conversion mismatch
zucchini-nlp pushed a commit to BakerBunker/transformers that referenced this pull request Apr 2, 2025

* [Feature] Support using flash-attention on Ascend NPU

* Fix qwen3 and qwen3_moe moduler conversion mismatch
zucchini-nlp pushed a commit to zucchini-nlp/transformers that referenced this pull request May 14, 2025

* [Feature] Support using flash-attention on Ascend NPU

* Fix qwen3 and qwen3_moe moduler conversion mismatch
@Byter-s

Byter-s commented Aug 7, 2025

Does this PR support rhymes-ai/Aria?

@SHYuanBest From reading the code in rhymes-ai/Aria, I see that the class AriaPretrainedModel inherits from the class PreTrainedModel in transformers, so it should be able to use FlashAttention2 on Ascend NPU with this PR. If you can try it, please tell me the result; looking forward to your reply :)

@FightingZhen Great to see FlashAttention2 supported on Ascend, but I'm wondering how to use this feature. When I run this code:

from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "Qwen/Qwen3-4B"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    attn_implementation="flash_attention_2",
    device_map="auto"
)

it throws the error importlib.metadata.PackageNotFoundError: No package metadata was found for flash_attn:

Traceback (most recent call last):
  File "/mnt/ascend/flash_attn_test.py", line 8, in <module>
    model = AutoModelForCausalLM.from_pretrained(
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/ascend/conda/arc/lib/python3.11/site-packages/transformers/models/auto/auto_factory.py", line 600, in from_pretrained
    return model_class.from_pretrained(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/ascend/conda/arc/lib/python3.11/site-packages/transformers/modeling_utils.py", line 316, in _wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/ascend/conda/arc/lib/python3.11/site-packages/transformers/modeling_utils.py", line 4986, in from_pretrained
    model = cls(config, *model_args, **model_kwargs)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/ascend/conda/arc/lib/python3.11/site-packages/transformers/models/qwen3/modeling_qwen3.py", line 430, in __init__
    super().__init__(config)
  File "/mnt/ascend/conda/arc/lib/python3.11/site-packages/transformers/modeling_utils.py", line 2227, in __init__
    self.config._attn_implementation_internal = self._check_and_adjust_attn_implementation(
                                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/ascend/conda/arc/lib/python3.11/site-packages/transformers/modeling_utils.py", line 2758, in _check_and_adjust_attn_implementation
    return self.get_correct_attn_implementation(applicable_attn_implementation, is_init_check)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/ascend/conda/arc/lib/python3.11/site-packages/transformers/modeling_utils.py", line 2784, in get_correct_attn_implementation
    self._flash_attn_2_can_dispatch(is_init_check)
  File "/mnt/ascend/conda/arc/lib/python3.11/site-packages/transformers/modeling_utils.py", line 2491, in _flash_attn_2_can_dispatch
    flash_attention_version = version.parse(importlib.metadata.version("flash_attn"))
                                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/ascend/conda/arc/lib/python3.11/importlib/metadata/__init__.py", line 1009, in version
    return distribution(distribution_name).version
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/ascend/conda/arc/lib/python3.11/importlib/metadata/__init__.py", line 982, in distribution
    return Distribution.from_name(distribution_name)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/ascend/conda/arc/lib/python3.11/importlib/metadata/__init__.py", line 565, in from_name
    raise PackageNotFoundError(name)
importlib.metadata.PackageNotFoundError: No package metadata was found for flash_attn
[ERROR] 2025-08-07-18:07:43 (PID:1213252, Device:-1, RankID:-1) ERR99999 UNKNOWN applicaiton exception

I am using transformers==4.55.0, torch==2.6.0, torch_npu==2.6.0rc1. Is this the right way to use it, or is something wrong with my environment?

@ArthurZucker
Collaborator

cc @FightingZhen I don't have NPUs so I can't test anything, happy if you can answer!

@FightingZhen
Contributor Author

@Byter-s Hi, Flash Attention 2 support on Ascend NPU is affected by other commits in 4.55.0, see #39844. I suggest using version 4.52.4 or the latest main branch; please try it and tell me the result, thanks :)

@ArthurZucker
Collaborator

BTW we can do a patch for the commits that broke it in 4.55!

@FightingZhen
Contributor Author

BTW we can do a patch for the commits that broke it in 4.55!

Thank you for your attention, we truly appreciate your suggestion. However, to stay consistent with the community's release plan, we do not recommend providing an additional patch specifically for version 4.55.0; it's better to wait for the next official community release.

Additionally, while we currently cannot integrate Ascend NPU test cases into the transformers CI pipeline, we are actively working on building an independent test pipeline internally. This pipeline will periodically verify the functionality of Flash Attention 2 on Ascend NPU, and once an error is found, we will fix it as fast as we can.

@FightingZhen
Contributor Author

FightingZhen commented Aug 13, 2025

@Byter-s Hi, Flash Attention 2 support on Ascend NPU is affected by other commits in 4.55.0, see #39844. I suggest using version 4.52.4 or the latest main branch; please try it and tell me the result, thanks :)

The main branch is broken again due to #40002 💢, so just try 4.52.4.

@Byter-s

Byter-s commented Aug 13, 2025

I see, it works fine on version 4.52.4. I've tried many cases but here is something weird:

  • Using attn_implementation="flash_attention_2" does not accelerate generation; the time cost to generate the same answer is almost the same as with attn_implementation="sdpa" (on 4.52.4).
  • flash_attention_2 is even slower than the default or sdpa (on 4.55.0, maybe affected by other commits?).

So I'm not sure whether the flash attention implementation is working properly.

@FightingZhen
Contributor Author

I see, it works fine on version 4.52.4. I've tried many cases but here is something weird:

  • Using attn_implementation="flash_attention_2" does not accelerate generation; the time cost to generate the same answer is almost the same as with attn_implementation="sdpa" (on 4.52.4).
  • flash_attention_2 is even slower than the default or sdpa (on 4.55.0, maybe affected by other commits?).

So I'm not sure whether the flash attention implementation is working properly.

When using transformers==4.52.4, if you can find the message Detect using FlashAttention2 on Ascend NPU in your logs, it means that you are using Flash Attention 2 on Ascend NPU (one way to surface that log is sketched below).
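
As a quick check (assuming the message is emitted through the standard transformers logger), raising the logging verbosity before loading the model should surface it; the model name and device map below are just taken from the earlier example:

from transformers import AutoModelForCausalLM
from transformers.utils import logging

# Raise verbosity so informational messages from transformers are printed.
logging.set_verbosity_info()

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-4B",
    torch_dtype="auto",
    attn_implementation="flash_attention_2",
    device_map="npu:0",
)
# On Ascend NPU with a working setup, the logs should contain:
#   Detect using FlashAttention2 on Ascend NPU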

@ArthurZucker
Collaborator

https://github.com/huggingface/transformers/releases/tag/v4.55.1 the patch is out btw!

@vasqu
Contributor

vasqu commented Aug 13, 2025

@FightingZhen sorry about breaking NPU support :/ The only things that should've changed are the import and unpad/pad functions (which are taken from the fa3 implementations)

Can you tell me which part(s) broke? Sadly, we cannot check with NPU hardware ourselves.

@FightingZhen FightingZhen deleted the support-ascend-fa branch August 14, 2025 01:52
@FightingZhen
Contributor Author

@FightingZhen sorry about breaking NPU support :/ The only things that should've changed are the import and unpad/pad functions (which are taken from the fa3 implementations)

Can you tell me which part(s) broke? Sadly, we cannot check with NPU hardware ourselves.

I have committed PR #40151 to fix this problem; you can find the root cause in that PR if you are interested.

@Byter-s

Byter-s commented Aug 14, 2025

When using transformers==4.52.4, if you can find the message Detect using FlashAttention2 on Ascend NPU in your logs, it means that you are using Flash Attention 2 on Ascend NPU.

@FightingZhen Can you provide demo code? I tried model inference with transformers==4.52.4 but there were no such logs:

from transformers import AutoModelForCausalLM, AutoTokenizer
from time import time

model_name = "/models/Qwen3-4B"
attn_implementation = "flash_attention_2"
# attn_implementation = "sdpa"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    attn_implementation=attn_implementation,
    device_map="npu:0"
)

prompt = "Give me a introduction to large language models."
messages = [
    {"role": "user", "content": prompt}
]
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

start_time = time()
# conduct text completion
generated_ids = model.generate(
    **model_inputs,
    max_new_tokens=16384,
    do_sample=False,
)
end_time = time()
output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist() 

content = tokenizer.decode(output_ids, skip_special_tokens=True)

print("content_len:", len(content))
print("attention implementation:", attn_implementation)
print("time cost:", end_time - start_time)

and the time cost is similar to sdpa. Is this the right way to use it?

@FightingZhen
Contributor Author

@FightingZhen Can you provide demo code? I tried model inference with transformers==4.52.4 but there were no such logs, and the time cost is similar to sdpa. Is this the right way to use it?

I usually use LLaMA-Factory to test this feature; the detailed steps are as follows:

  1. Clone LLaMA-Factory from the main branch.
  2. Modify the LLaMA-Factory code here from is_flash_attn_2_available to is_flash_attn_available, where is_flash_attn_available is imported from transformers.modeling_flash_attention_utils (see the sketch after this list).
  3. Install LLaMA-Factory from source.
  4. export DISABLE_VERSION_CHECK=1 may be required to skip the check for the flash-attn package, which is unavailable on Ascend NPU.
  5. Use examples/train_lora/llama3_lora_sft.sh to test; I suggest setting --gradient_accumulation_steps to 1 and adding --flash_attn fa2 to use FlashAttention2 on Ascend NPU.
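
For reference, a sketch of the step-2 change; only the import path and the two function names come from the steps above, while the surrounding function and argument names are illustrative:

# Before (only True when the flash-attn package is installed, so always False on NPU):
# from transformers.utils import is_flash_attn_2_available

# After (also True when torch_npu's fused flash attention can be used):
from transformers.modeling_flash_attention_utils import is_flash_attn_available


def configure_attn_implementation(model_args) -> None:
    # `flash_attn == "fa2"` mirrors the --flash_attn fa2 option above;
    # the rest of this function is hypothetical.
    if getattr(model_args, "flash_attn", None) == "fa2" and not is_flash_attn_available():
        raise ValueError("FlashAttention2 is not available in this environment.")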

soghomon-b pushed a commit to soghomon-b/transformers that referenced this pull request Aug 24, 2025

* [Feature] Support using flash-attention on Ascend NPU

* Fix qwen3 and qwen3_moe moduler conversion mismatch
Development

Successfully merging this pull request may close these issues.

Can not use flash-attention and flash-varlen-attention on Ascend NPU