
[Bug] MLA slower than default for small context long outputs *and* generating bad output reproducibly  #3716

@pseudotensor

Description

Checklist

  1. I have searched related issues but cannot get the expected help.
  2. The bug has not been fixed in the latest version.
  3. Please note that if the bug-related issue you submitted lacks corresponding environment info and a minimal reproducible demo, it will be challenging for us to reproduce and resolve the issue, reducing the likelihood of receiving feedback.
  4. If the issue you raised is not a bug but a question, please raise a discussion at https://github.com/sgl-project/sglang/discussions/new/choose. Otherwise, it will be closed.
  5. Please use English, otherwise it will be closed.

Describe the bug

Enabling MLA (--enable-flashinfer-mla) is noticeably slower for small-context, long-output generation. I show two benchmarks: one varies request concurrency, the other varies input token count (since MLA is supposed to help most at long context).

Reproduction

docker stop v3 ; docker remove v3
docker run -d --gpus all --shm-size 32g -p 5000:5000 --name v3 -v ~/.cache/huggingface:/root/.cache/huggingface --ipc=host lmsysorg/sglang:v0.4.3.post2-cu125 \
    python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-V3 --tp 8 \
    --enable-torch-compile \
    --enable-flashinfer-mla \
    --trust-remote-code --port 5000 --host 0.0.0.0 \
    --api-key EMPTY  --random-seed 1234

or:

docker stop r1 ; docker remove r1
docker run -d --gpus all --shm-size 32g -p 5000:5000 --name r1 -v ~/.cache/huggingface:/root/.cache/huggingface --ipc=host lmsysorg/sglang:v0.4.3.post2-cu125 \
    python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-R1 --tp 8 \
    --enable-torch-compile \
    --enable-flashinfer-mla \
    --trust-remote-code --port 5000 --host 0.0.0.0 \
    --api-key EMPTY  --random-seed 1234
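
Before benchmarking, a quick sanity check (not part of the original repro, just a suggested sketch) can confirm the server is reachable and which model it is serving. It assumes the same 'SET' host placeholder used in the client scripts below, port 5000, and the EMPTY api key from the docker command:

import openai

# Sanity-check sketch: list the models the freshly launched server advertises.
client = openai.Client(base_url="http://SET:5000/v1", api_key="EMPTY")
for model in client.models.list():
    print(model.id)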

Then on the client:

import os
os.environ['HUGGING_FACE_HUB_TOKEN'] = 'hf_xxx'  # token redacted; use your own

import time
import openai
import concurrent.futures
from transformers import AutoTokenizer
from datetime import datetime


def count_tokens(prompt, model_name="deepseek-ai/DeepSeek-V3"):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return len(tokenizer.encode(prompt))

def measure_performance(prompt, max_tokens=8192):
    url = 'SET'
    api_key = 'SET'
    client = openai.Client(base_url=f"http://{url}:5000/v1", api_key=api_key)

    token_count = count_tokens(prompt)
    start_time = time.time()

    response = client.chat.completions.create(
        model="deepseek-ai/DeepSeek-R1",
        messages=[{"role": "user", "content": prompt}],
        temperature=0,
        max_tokens=max_tokens,
        stream=True  # Enable streaming mode
    )

    first_token_time = None
    total_tokens = 0
    first_token_received = False

    for chunk in response:
        if not first_token_received:
            first_token_time = time.time() - start_time
            first_token_received = True
        total_tokens += len((chunk.choices[0].delta.content or "").split())  # delta.content can be None on role/finish chunks

    total_time = time.time() - start_time
    tps = total_tokens / total_time if total_time > 0 else 0

    return {
        "concurrent_requests": None,  # Placeholder, to be updated in multi-run function
        "prompt_length": token_count,
        "max_tokens": max_tokens,
        "time_to_first_token": first_token_time,
        "tokens_per_second": tps,
        "total_time": total_time,
        "total_tokens": total_tokens,
    }

def run_concurrent_tests(prompt, num_requests_list, max_tokens=8192):
    results = []

    for num_requests in num_requests_list:
        print("num_requests: %s" % num_requests, flush=True)
        with concurrent.futures.ThreadPoolExecutor(max_workers=num_requests) as executor:
            futures = [executor.submit(measure_performance, prompt, max_tokens) for _ in range(num_requests)]
            concurrent_results = [future.result() for future in concurrent.futures.as_completed(futures)]

            avg_time_to_first_token = sum(res["time_to_first_token"] for res in concurrent_results) / num_requests
            avg_tokens_per_second = sum(res["tokens_per_second"] for res in concurrent_results) / num_requests
            avg_total_time = sum(res["total_time"] for res in concurrent_results) / num_requests
            total_tokens = sum(res["total_tokens"] for res in concurrent_results)

            results.append({
                "concurrent_requests": num_requests,
                "prompt_length": concurrent_results[0]["prompt_length"],
                "max_tokens": max_tokens,
                "time_to_first_token": avg_time_to_first_token,
                "tokens_per_second": avg_tokens_per_second,
                "total_time": avg_total_time,
                "total_tokens": total_tokens,
            })

    return results

def generate_markdown(results):
    md_report = """# Concurrent Request Performance Analysis

## Summary of Response Time and Throughput

| Concurrent Requests | Prompt Length | Max Tokens | Time to First Token (s) | Tokens per Second | Total Time (s) | Total Tokens |
|--------------------|--------------|------------|-------------------------|-------------------|---------------|-------------|
"""

    for res in results:
        md_report += f"| {res['concurrent_requests']} | {res['prompt_length']} | {res['max_tokens']} | {res['time_to_first_token']:.4f} | {res['tokens_per_second']:.4f} | {res['total_time']:.4f} | {res['total_tokens']} |\n"

    return md_report

def main():
    prompt = "Write an extremely long story."
    num_requests_list = [1, 8, 16, 32]  # Different levels of concurrency

    results = run_concurrent_tests(prompt, num_requests_list)

    markdown_report = generate_markdown(results)

    with open("concurrent_performance_report.md", "w") as f:
        f.write(markdown_report)

    print("Concurrent performance report generated: concurrent_performance_report.md")

if __name__ == "__main__":
    main()

gives:

[Image: concurrency benchmark results with MLA enabled]

Compared to the run without MLA on 0.4.2:

#3196 (comment)

i.e.

[Image: concurrency benchmark results without MLA on 0.4.2]
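
A note on methodology: tokens_per_second above is computed from whitespace-split words in the streamed deltas, which undercounts actual model tokens. A minimal sketch of a more faithful count, reusing the tokenizer already loaded by count_tokens (assuming response is the streaming response and tokenizer is the DeepSeek-V3 AutoTokenizer):

# Sketch only: accumulate the streamed text and count real model tokens with
# the tokenizer instead of splitting on whitespace.
output_text = ""
for chunk in response:
    delta = chunk.choices[0].delta.content or ""  # delta.content can be None on some chunks
    output_text += delta
total_tokens = len(tokenizer.encode(output_text))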

Then, for context-length testing:

import time
import openai
from transformers import AutoTokenizer
from datetime import datetime

def count_tokens(prompt, model_name="deepseek-ai/DeepSeek-V3"):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return len(tokenizer.encode(prompt))

def measure_performance(prompt, model="deepseek-ai/DeepSeek-V3", max_tokens=64):
    url = 'SET'
    api_key = 'SET'
    client = openai.Client(base_url=f"http://{url}:5000/v1", api_key=api_key)

    token_count = count_tokens(prompt, model)
    start_time = time.time()

    response = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        temperature=0,
        max_tokens=max_tokens,
        stream=True  # Enable streaming mode
    )

    first_token_time = None
    total_tokens = 0
    first_token_received = False

    for chunk in response:
        if not first_token_received:
            first_token_time = time.time() - start_time
            first_token_received = True
        total_tokens += len((chunk.choices[0].delta.content or "").split())  # delta.content can be None on role/finish chunks

    total_time = time.time() - start_time
    tps = total_tokens / total_time if total_time > 0 else 0

    return {
        "prompt_length": token_count,
        "max_tokens": max_tokens,
        "time_to_first_token": first_token_time,
        "tokens_per_second": tps,
        "total_time": total_time,
        "total_tokens": total_tokens,
    }

def generate_markdown(results):
    md_report = """# Token Performance Analysis

## Summary of Response Time and Throughput

| Prompt Length | Max Tokens | Time to First Token (s) | Tokens per Second | Total Time (s) | Total Tokens |
|--------------|------------|-------------------------|-------------------|---------------|-------------|
"""

    for res in results:
        md_report += f"| {res['prompt_length']} | {res['max_tokens']} | {res['time_to_first_token']:.4f} | {res['tokens_per_second']:.4f} | {res['total_time']:.4f} | {res['total_tokens']} |\n"

    return md_report

def main():
    test_cases = [
        ("Write an extremely long story.", 8192),
        ("word " * 8000 + "Write an extremely long story.", 8192),
        ("word " * 118000 + "Write an extremely long story.", 8192)
    ]

    results = []
    for prompt, max_tokens in test_cases:
        res = measure_performance(prompt, max_tokens=max_tokens)
        results.append(res)

    markdown_report = generate_markdown(results)

    with open("performance_report.md", "w") as f:
        f.write(markdown_report)

    print("Performance report generated: performance_report.md")

if __name__ == "__main__":
    main()

One gets:

[Image: context-length benchmark results with MLA enabled]

Before, without MLA on 0.4.2, this gave: #3196 (comment)
i.e.

[Image: context-length benchmark results without MLA on 0.4.2]

That is, long context is much faster with MLA enabled, roughly 3x at the longest contexts compared to 0.4.2.

However, for small-context inputs, throughput dropped from about 47 tokens/sec to about 40 tokens/sec.
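
For reference, the three test prompts span roughly three orders of magnitude in input size. A small sketch (not part of the original report) to print their exact token counts with the script's own count_tokens helper:

# Sketch only: report the token count of each test prompt so the table rows
# can be tied to concrete input sizes (exact counts depend on the
# DeepSeek-V3 tokenizer).
prompts = [
    "Write an extremely long story.",
    "word " * 8000 + "Write an extremely long story.",
    "word " * 118000 + "Write an extremely long story.",
]
for p in prompts:
    print(count_tokens(p))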

Environment

INFO 02-20 03:00:02 __init__.py:190] Automatically detected platform cuda.
Python: 3.10.12 (main, Jan 17 2025, 14:35:34) [GCC 11.4.0]
CUDA available: True
GPU 0,1,2,3,4,5,6,7: NVIDIA H200
GPU 0,1,2,3,4,5,6,7 Compute Capability: 9.0
CUDA_HOME: /usr/local/cuda
NVCC: Cuda compilation tools, release 12.4, V12.4.131
CUDA Driver Version: 550.144.03
PyTorch: 2.5.1+cu124
sgl_kernel: 0.0.3.post6
flashinfer: 0.2.1.post2+cu124torch2.5
triton: 3.1.0
transformers: 4.48.3
torchao: 0.8.0
numpy: 1.26.4
aiohttp: 3.11.12
fastapi: 0.115.8
hf_transfer: 0.1.9
huggingface_hub: 0.28.1
interegular: 0.3.3
modelscope: 1.23.0
orjson: 3.10.15
packaging: 24.2
psutil: 7.0.0
pydantic: 2.10.6
multipart: 0.0.20
zmq: 26.2.1
uvicorn: 0.34.0
uvloop: 0.21.0
vllm: 0.7.2
openai: 1.63.2
tiktoken: 0.9.0
anthropic: 0.45.2
decord: 0.6.0
NVIDIA Topology: 
        GPU0    GPU1    GPU2    GPU3    GPU4    GPU5    GPU6    GPU7    NIC0    NIC1    NIC2    NIC3    NIC4    NIC5    NIC6    NIC7    NIC8    NIC9    NIC10   NIC11   CPU Affinity    NUMA Affinity   GPU NUMA ID
GPU0     X      NV18    NV18    NV18    NV18    NV18    NV18    NV18    NODE    NODE    NODE    NODE    SYS     NODE    PIX     SYS     SYS     SYS     SYS     SYS     0-95,192-287    0               N/A
GPU1    NV18     X      NV18    NV18    NV18    NV18    NV18    NV18    NODE    PIX     PHB     PHB     SYS     NODE    NODE    SYS     SYS     SYS     SYS     SYS     0-95,192-287    0               N/A
GPU2    NV18    NV18     X      NV18    NV18    NV18    NV18    NV18    PIX     NODE    NODE    NODE    SYS     NODE    NODE    SYS     SYS     SYS     SYS     SYS     0-95,192-287    0               N/A
GPU3    NV18    NV18    NV18     X      NV18    NV18    NV18    NV18    NODE    NODE    NODE    NODE    SYS     PIX     NODE    SYS     SYS     SYS     SYS     SYS     0-95,192-287    0               N/A
GPU4    NV18    NV18    NV18    NV18     X      NV18    NV18    NV18    SYS     SYS     SYS     SYS     NODE    SYS     SYS     NODE    NODE    NODE    PIX     NODE    96-191,288-383  1               N/A
GPU5    NV18    NV18    NV18    NV18    NV18     X      NV18    NV18    SYS     SYS     SYS     SYS     NODE    SYS     SYS     NODE    NODE    NODE    NODE    PIX     96-191,288-383  1               N/A
GPU6    NV18    NV18    NV18    NV18    NV18    NV18     X      NV18    SYS     SYS     SYS     SYS     NODE    SYS     SYS     PIX     PHB     PHB     NODE    NODE    96-191,288-383  1               N/A
GPU7    NV18    NV18    NV18    NV18    NV18    NV18    NV18     X      SYS     SYS     SYS     SYS     PIX     SYS     SYS     NODE    NODE    NODE    NODE    NODE    96-191,288-383  1               N/A
NIC0    NODE    NODE    PIX     NODE    SYS     SYS     SYS     SYS      X      NODE    NODE    NODE    SYS     NODE    NODE    SYS     SYS     SYS     SYS     SYS
NIC1    NODE    PIX     NODE    NODE    SYS     SYS     SYS     SYS     NODE     X      PHB     PHB     SYS     NODE    NODE    SYS     SYS     SYS     SYS     SYS
NIC2    NODE    PHB     NODE    NODE    SYS     SYS     SYS     SYS     NODE    PHB      X      PIX     SYS     NODE    NODE    SYS     SYS     SYS     SYS     SYS
NIC3    NODE    PHB     NODE    NODE    SYS     SYS     SYS     SYS     NODE    PHB     PIX      X      SYS     NODE    NODE    SYS     SYS     SYS     SYS     SYS
NIC4    SYS     SYS     SYS     SYS     NODE    NODE    NODE    PIX     SYS     SYS     SYS     SYS      X      SYS     SYS     NODE    NODE    NODE    NODE    NODE
NIC5    NODE    NODE    NODE    PIX     SYS     SYS     SYS     SYS     NODE    NODE    NODE    NODE    SYS      X      NODE    SYS     SYS     SYS     SYS     SYS
NIC6    PIX     NODE    NODE    NODE    SYS     SYS     SYS     SYS     NODE    NODE    NODE    NODE    SYS     NODE     X      SYS     SYS     SYS     SYS     SYS
NIC7    SYS     SYS     SYS     SYS     NODE    NODE    PIX     NODE    SYS     SYS     SYS     SYS     NODE    SYS     SYS      X      PHB     PHB     NODE    NODE
NIC8    SYS     SYS     SYS     SYS     NODE    NODE    PHB     NODE    SYS     SYS     SYS     SYS     NODE    SYS     SYS     PHB      X      PIX     NODE    NODE
NIC9    SYS     SYS     SYS     SYS     NODE    NODE    PHB     NODE    SYS     SYS     SYS     SYS     NODE    SYS     SYS     PHB     PIX      X      NODE    NODE
NIC10   SYS     SYS     SYS     SYS     PIX     NODE    NODE    NODE    SYS     SYS     SYS     SYS     NODE    SYS     SYS     NODE    NODE    NODE     X      NODE
NIC11   SYS     SYS     SYS     SYS     NODE    PIX     NODE    NODE    SYS     SYS     SYS     SYS     NODE    SYS     SYS     NODE    NODE    NODE    NODE     X 

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

NIC Legend:

  NIC0: mlx5_0
  NIC1: mlx5_1
  NIC2: mlx5_2
  NIC3: mlx5_3
  NIC4: mlx5_4
  NIC5: mlx5_5
  NIC6: mlx5_6
  NIC7: mlx5_7
  NIC8: mlx5_8
  NIC9: mlx5_9
  NIC10: mlx5_10
  NIC11: mlx5_11


ulimit soft: 1048576

Metadata

Labels: bug (Something isn't working), deepseek