Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 0 additions & 10 deletions examples/controlnet/train_controlnet_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,9 +473,6 @@ def parse_args(input_args=None):
parser.add_argument(
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
)
parser.add_argument(
"--enable_npu_flash_attention", action="store_true", help="Whether or not to use npu flash attention."
)
parser.add_argument(
"--set_grads_to_none",
action="store_true",
Expand Down Expand Up @@ -970,13 +967,6 @@ def load_model_hook(models, input_dir):
accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)

if args.enable_npu_flash_attention:
if is_torch_npu_available():
logger.info("npu flash attention enabled.")
flux_transformer.enable_npu_flash_attention()
else:
raise ValueError("npu flash attention requires torch_npu extensions and is supported only on npu devices.")

if args.enable_xformers_memory_efficient_attention:
if is_xformers_available():
import xformers
Expand Down
11 changes: 9 additions & 2 deletions src/diffusers/models/transformers/transformer_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
FluxAttnProcessor2_0,
FluxAttnProcessor2_0_NPU,
FusedFluxAttnProcessor2_0,
FusedFluxAttnProcessor2_0_NPU,
)
from ...models.modeling_utils import ModelMixin
from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
Expand Down Expand Up @@ -141,7 +142,10 @@ def __init__(
self.norm1_context = AdaLayerNormZero(dim)

if hasattr(F, "scaled_dot_product_attention"):
processor = FluxAttnProcessor2_0()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's keep the default logic simple i.e. remove the changes from this file
you can use NPU with set_attn_processor, no?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@yiyixuxu The attention processor module already supports NPU flash attention, but only via AttnProcessorNPU. Would you prefer that I add a function in attention_processor.py to switch the processor there? Right now the selection happens only in the init, so there shouldn't be any logic changes. Please let me know which approach you prefer so I can modify the PR accordingly. Thanks

The attention processor module already supports NPU flash attention, but only via AttnProcessorNPU. Would you prefer that I add a function in attention_processor.py to switch the processor there? Right now the selection happens only in the init, so there shouldn't be any logic changes. Please let me know which approach you prefer so I can modify the PR accordingly. Thanks

if is_torch_npu_available():
processor = FluxAttnProcessor2_0_NPU()
else:
processor = FluxAttnProcessor2_0()
else:
raise ValueError(
"The current PyTorch version does not support the `scaled_dot_product_attention` function."
Expand Down Expand Up @@ -407,7 +411,10 @@ def fuse_qkv_projections(self):
if isinstance(module, Attention):
module.fuse_projections(fuse=True)

self.set_attn_processor(FusedFluxAttnProcessor2_0())
if is_torch_npu_available():
self.set_attn_processor(FusedFluxAttnProcessor2_0_NPU())
else:
self.set_attn_processor(FusedFluxAttnProcessor2_0())

# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
def unfuse_qkv_projections(self):
Expand Down