[Efficiency] The llama model with flash attention is slower than that without flash attention #26990

@KexinFeng

Description

System Info

The test was run with the fix from #26984 applied.

- `transformers` version: 4.34.0
- Platform: Linux-5.15.0-1045-aws-x86_64-with-glibc2.31
- Python version: 3.9.18
- Huggingface_hub version: 0.17.3
- Safetensors version: 0.4.0
- Accelerate version: 0.23.0
- Accelerate config:    not found
- PyTorch version (GPU?): 2.0.1+cu118 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using GPU in script?: <fill in>
- Using distributed or parallel set-up in script?: <fill in>

Who can help?

@ArthurZucker and @younesbelkada

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

The model-loading code:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def get_model_tokenizer(model_id, flash_attn=False):
    # Note: model_id is shadowed by the hardcoded path below.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_id_or_path = "huggyllama/llama-7b"
    model = AutoModelForCausalLM.from_pretrained(
        model_id_or_path,
        device_map='auto' if device.type == 'cuda' else 'cpu',
        use_flash_attention_2=flash_attn)
    # HuggingfaceBlock is a wrapper from the author's serving stack;
    # its definition is not shown in this issue.
    lm_block = HuggingfaceBlock(model)
    tokenizer = AutoTokenizer.from_pretrained(model_id_or_path,
                                              padding_side='left')
    tokenizer.pad_token = "[PAD]"

    return lm_block, tokenizer
```
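The benchmark driver is not included in the report; presumably the loader is invoked once per configuration, along these lines (hypothetical call, not shown in the issue):

```python
# Hypothetical invocation; the actual driver code is not shown in the issue.
lm_block, tokenizer = get_model_tokenizer("huggyllama/llama-7b", flash_attn=True)
```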

The benchmark settings:

```python
Input_length = 760
batch_size = 13
Max_gen_token = [300, 100, 50, 20]
```
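For context, here is a minimal, self-contained sketch of how per-token latency at these settings might be measured. It is an assumption-laden stand-in (plain `model.generate`, a repeated filler prompt, greedy decoding, fp16) rather than the author's actual harness, which goes through `HuggingfaceBlock`:

```python
# Minimal latency sketch under stated assumptions; NOT the author's harness.
import time

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "huggyllama/llama-7b",
    device_map="auto",
    torch_dtype=torch.float16,        # assumption: FA2 kernels need fp16/bf16
    use_flash_attention_2=True)       # flip to False for the baseline run
tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b",
                                          padding_side="left")
tokenizer.pad_token = "[PAD]"

prompt = "hello " * 760               # crude stand-in for a 760-token input
batch = tokenizer([prompt] * 13,      # batch_size = 13, left-padded
                  return_tensors="pt", padding=True,
                  truncation=True, max_length=760).to(model.device)

for max_new in [300, 100, 50, 20]:    # Max_gen_token
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    model.generate(**batch, max_new_tokens=max_new, min_new_tokens=max_new,
                   do_sample=False, pad_token_id=tokenizer.pad_token_id)
    torch.cuda.synchronize()          # let GPU work finish before reading clock
    dt = time.perf_counter() - t0
    print(f"{dt / max_new * 1e3:.1f} ms/token at max_new_tokens={max_new}")
```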

When `flash_attn == True`:

token_latency: [18.3 ms/token, 20.7 ms/token, 26.4 ms/token, 44.1 ms/token]

When `flash_attn == False`:

token_latency: [14.1 ms/token, 17.8 ms/token, 24.3 ms/token, 44.2 ms/token]

Expected behavior

Flash attention should accelerate inference, yet at these settings it consistently increases per-token latency (e.g. 18.3 vs. 14.1 ms/token at 300 generated tokens), only reaching parity at the shortest generation length.
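One quick sanity check (hedged; not part of the original report) is to rule out a silent fallback by confirming which attention class was actually instantiated. In transformers 4.34, Llama swaps in `LlamaFlashAttention2` when `use_flash_attention_2=True`:

```python
# Sanity check: confirm the flash-attention path is active (transformers 4.34).
# `model` here is the raw LlamaForCausalLM, before any HuggingfaceBlock wrapping.
attn_cls = type(model.model.layers[0].self_attn).__name__
print(attn_cls)  # "LlamaFlashAttention2" if FA2 is active, else "LlamaAttention"
```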
