
[Bug] Llama4 OOM with 400k input request #5212

@CatherineSue

Description

Checklist

  • 1. I have searched related issues but cannot get the expected help.
  • 2. The bug has not been fixed in the latest version.
  • 3. Please note that if the bug-related issue you submitted lacks corresponding environment info and a minimal reproducible demo, it will be challenging for us to reproduce and resolve the issue, reducing the likelihood of receiving feedback.
  • 4. If the issue you raised is not a bug but a question, please raise a discussion at https://github.com/sgl-project/sglang/discussions/new/choose ; otherwise, it will be closed.
  • 5. Please use English, otherwise it will be closed.

Describe the bug

I started a server on an 8xH100 node with meta-llama/Llama-4-Scout-17B-16E-Instruct using the following command:

python -m sglang.launch_server --model meta-llama/Llama-4-Scout-17B-16E-Instruct \
--port 8080 \
--tp-size 8 \
--chat-template llama-4 \
--attention-backend=fa3 \
--mem-fraction-static=0.8 \
--context-length 1000000
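As a rough budget check, here is a sketch of how much memory per GPU is left outside the static pool. It assumes that `--mem-fraction-static` is the fraction of total GPU memory reserved for model weights plus the KV cache pool. It also uses the 79.44 GiB total capacity that PyTorch reports in the OOM message below. Under those assumptions, only about 16 GiB per GPU remains for prefill activations, CUDA graphs, and allocator overhead:

```python
# Rough per-GPU memory budget for --mem-fraction-static=0.8 (values in GiB).
# Assumption: the fraction applies to total GPU memory and covers weights + KV cache pool.
TOTAL_GIB = 79.44          # "total capacity" reported by PyTorch in the OOM message
MEM_FRACTION_STATIC = 0.8  # from the launch command

static_gib = TOTAL_GIB * MEM_FRACTION_STATIC  # model weights + KV cache pool
headroom_gib = TOTAL_GIB - static_gib         # everything else: activations, CUDA graphs, ...

print(f"static pool ~ {static_gib:.2f} GiB, headroom ~ {headroom_gib:.2f} GiB")
```

The whole prefill of a very long prompt has to fit inside that headroom, which is what the traceback below runs out of.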

Sending a request with around 400k input then causes a CUDA OOM:

[2025-04-09 17:19:56] Received sigquit from a child process. It usually means the child failed.
[2025-04-09 17:19:56 TP5] Scheduler hit an exception: Traceback (most recent call last):
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 2045, in run_scheduler_process
    scheduler.event_loop_normal()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 608, in event_loop_normal
    result = self.run_batch(batch)
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 1395, in run_batch
    logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
  File "/sgl-workspace/sglang/python/sglang/srt/managers/tp_worker.py", line 175, in forward_batch_generation
    logits_output = self.model_runner.forward(forward_batch)
  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 1001, in forward
    return self.forward_extend(
  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 962, in forward_extend
    return self.model.forward(
  File "/sgl-workspace/sglang/python/sglang/srt/models/mllama4.py", line 83, in forward
    hs = general_mm_embed_routine(
  File "/sgl-workspace/sglang/python/sglang/srt/managers/mm_utils.py", line 354, in general_mm_embed_routine
    inputs_embeds = embed_tokens(input_ids)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/sgl-workspace/sglang/python/sglang/srt/layers/vocab_parallel_embedding.py", line 482, in forward
    output_parallel = self.quant_method.embedding(self, masked_input.long())
  File "/sgl-workspace/sglang/python/sglang/srt/layers/vocab_parallel_embedding.py", line 62, in embedding
    return F.embedding(input_, layer.weight)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/functional.py", line 2551, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 3.43 GiB. GPU 5 has a total capacity of 79.44 GiB of which 2.64 GiB is free. Process 679812 has 76.79 GiB memory in use. Of the allocated memory 72.76 GiB is allocated by PyTorch, with 26.38 MiB allocated in private pools (e.g., CUDA Graphs), and 293.55 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

[2025-04-09 17:19:56 TP6] Scheduler hit an exception: Traceback (most recent call last):
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 2045, in run_scheduler_process
    scheduler.event_loop_normal()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 608, in event_loop_normal
    result = self.run_batch(batch)
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 1395, in run_batch
    logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
  File "/sgl-workspace/sglang/python/sglang/srt/managers/tp_worker.py", line 175, in forward_batch_generation
    logits_output = self.model_runner.forward(forward_batch)
  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 1001, in forward
    return self.forward_extend(
  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 962, in forward_extend
    return self.model.forward(
  File "/sgl-workspace/sglang/python/sglang/srt/models/mllama4.py", line 83, in forward
    hs = general_mm_embed_routine(
  File "/sgl-workspace/sglang/python/sglang/srt/managers/mm_utils.py", line 354, in general_mm_embed_routine
    inputs_embeds = embed_tokens(input_ids)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/sgl-workspace/sglang/python/sglang/srt/layers/vocab_parallel_embedding.py", line 482, in forward
    output_parallel = self.quant_method.embedding(self, masked_input.long())
  File "/sgl-workspace/sglang/python/sglang/srt/layers/vocab_parallel_embedding.py", line 62, in embedding
    return F.embedding(input_, layer.weight)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/functional.py", line 2551, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 3.43 GiB. GPU 6 has a total capacity of 79.44 GiB of which 2.64 GiB is free. Process 679813 has 76.79 GiB memory in use. Of the allocated memory 72.76 GiB is allocated by PyTorch, with 26.38 MiB allocated in private pools (e.g., CUDA Graphs), and 293.55 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

[2025-04-09 17:19:56] Received sigquit from a child process. It usually means the child failed.
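The numbers in the OOM message can be checked with a back-of-the-envelope sketch. It assumes details that the issue does not state: a text-model hidden size of 5120 for Llama-4-Scout, bf16 activations, and an embedding output of shape [num_tokens, hidden_size] that is materialized in full on every TP rank before the all-reduce. Under those assumptions, the failed 3.43 GiB allocation corresponds to roughly 360k tokens embedded in a single forward pass. That is of the same order as the input described above. The request also does not fit even if the reserved-but-unallocated memory were reclaimed, so the `expandable_segments` hint alone is unlikely to help:

```python
# Back-of-the-envelope check of the failed allocation in F.embedding.
# Assumptions (not stated in the issue): hidden_size = 5120, bf16 (2 bytes/element),
# and a full [num_tokens, hidden_size] embedding output on each TP rank.
GIB = 1024**3

HIDDEN_SIZE = 5120   # assumed text hidden size for Llama-4-Scout
BYTES_PER_ELEM = 2   # bf16, assumed


def embedding_output_gib(num_tokens: int) -> float:
    """Size in GiB of one [num_tokens, HIDDEN_SIZE] bf16 tensor."""
    return num_tokens * HIDDEN_SIZE * BYTES_PER_ELEM / GIB


# Numbers copied from the OOM message (GPU 5 / GPU 6)
requested_gib = 3.43
free_gib = 2.64
reserved_unallocated_gib = 293.55 / 1024  # "293.55 MiB is reserved by PyTorch but unallocated"

# Even if every reserved-but-unallocated byte were reusable, the request would not fit.
fits_after_defrag = requested_gib <= free_gib + reserved_unallocated_gib

# Invert the size formula to estimate how many tokens were embedded at once.
est_tokens = requested_gib * GIB / (HIDDEN_SIZE * BYTES_PER_ELEM)

print(f"fits after defrag: {fits_after_defrag}")
print(f"estimated tokens in the forward pass: ~{est_tokens:,.0f}")
print(f"embedding output at 1,000,000 tokens: {embedding_output_gib(1_000_000):.2f} GiB")
```

With `--context-length 1000000`, the same formula suggests the embedding output alone could reach about 9.5 GiB per rank for a single full-length prefill. This buffer is only one of several activation buffers in the forward pass.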

Reproduction

Start the server on an 8xH100 node:

python -m sglang.launch_server --model meta-llama/Llama-4-Scout-17B-16E-Instruct \
--port 8080 \
--tp-size 8 \
--chat-template llama-4 \
--attention-backend=fa3 \
--context-length 1000000

Then run python3 send_llama_request.py, where send_llama_request.py contains:

import json

import requests

# 200000 repetitions of "1 " (400k characters) in a single text part
payload = {
    "model": "sgl-model",
    "messages": [
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "1 " * 200000,
                },
            ],
        }
    ],
    "max_tokens": 200,
    "temperature": 0.0,
    "top_p": 0.75,
    "top_k": -1,
    "stream": True,
    "stream_options": {
        "include_usage": True,
    },
    "ignore_eos": True,
}

# Send the POST request and stream the server-sent events back
response = requests.post(
    "http://localhost:8080/v1/chat/completions",
    # "http://localhost:9922/v1/chat/completions",
    headers={"Content-Type": "application/json", "opc-request-id": "xfrjoiwejfioewngrinel"},
    json=payload,
    stream=True,
)

generated_text = ""

# Check if the response was successful
if response.status_code == 200:
    for chunk in response.iter_lines(chunk_size=None):
        print(chunk)
        chunk = chunk.strip()
        if not chunk:
            continue
        # Each event line looks like b"data: {...}"; drop the prefix
        stem = b"data: "
        if chunk.startswith(stem):
            chunk = chunk[len(stem):]
        if chunk == b"[DONE]":
            continue

        data = json.loads(chunk)
        if "error" in data:
            error = data["error"]
            raise RuntimeError(f"{error.get('code')}: {error['message']}")

        # With include_usage=True, the final usage chunk may carry an empty "choices" list
        choices = data.get("choices") or []
        if not choices:
            continue
        delta = choices[0].get("delta", {})
        if delta.get("content"):
            generated_text += delta["content"]

    print("Generated text:", generated_text)
    print("Status:", response.status_code)
else:
    # The body may not be JSON (e.g. a plain-text error), so print it raw
    print("Error:", response.status_code, response.text)

Environment

Python: 3.10.16 (main, Dec  4 2024, 08:53:37) [GCC 9.4.0]
CUDA available: True
GPU 0,1,2,3,4,5,6,7: NVIDIA H100 80GB HBM3
GPU 0,1,2,3,4,5,6,7 Compute Capability: 9.0
CUDA_HOME: /usr/local/cuda
NVCC: Cuda compilation tools, release 12.4, V12.4.131
CUDA Driver Version: 560.35.03
PyTorch: 2.5.1+cu124
sglang: 0.4.5
sgl_kernel: 0.0.8
flashinfer: 0.1.6+cu124torch2.4
triton: 3.1.0
transformers: 4.51.0
torchao: 0.10.0
numpy: 1.26.4
aiohttp: 3.11.11
fastapi: 0.115.6
hf_transfer: 0.1.8
huggingface_hub: 0.30.1
interegular: 0.3.3
modelscope: 1.21.1
orjson: 3.10.13
outlines: 0.0.46
packaging: 24.2
psutil: 6.1.1
pydantic: 2.10.4
multipart: Module Not Found
zmq: Module Not Found
uvicorn: 0.34.0
uvloop: 0.21.0
vllm: 0.6.4.post1
xgrammar: 0.1.17
openai: 1.59.3
tiktoken: 0.7.0
anthropic: 0.42.0
litellm: 1.56.10
decord: 0.6.0
NVIDIA Topology:
        GPU0  GPU1  GPU2  GPU3  GPU4  GPU5  GPU6  GPU7  NIC0  NIC1  NIC2  NIC3  NIC4  NIC5  NIC6  NIC7  NIC8  NIC9  NIC10 NIC11 NIC12 NIC13 NIC14 NIC15 NIC16 NIC17 CPU Affinity    NUMA Affinity  GPU NUMA ID
GPU0    X     NV18  NV18  NV18  NV18  NV18  NV18  NV18  PXB   PXB   NODE  NODE  NODE  NODE  NODE  NODE  NODE  SYS   SYS   SYS   SYS   SYS   SYS   SYS   SYS   SYS   0-55,112-167    0              N/A
GPU1    NV18  X     NV18  NV18  NV18  NV18  NV18  NV18  NODE  NODE  NODE  PXB   PXB   NODE  NODE  NODE  NODE  SYS   SYS   SYS   SYS   SYS   SYS   SYS   SYS   SYS   0-55,112-167    0              N/A
GPU2    NV18  NV18  X     NV18  NV18  NV18  NV18  NV18  NODE  NODE  NODE  NODE  NODE  PXB   PXB   NODE  NODE  SYS   SYS   SYS   SYS   SYS   SYS   SYS   SYS   SYS   0-55,112-167    0              N/A
GPU3    NV18  NV18  NV18  X     NV18  NV18  NV18  NV18  NODE  NODE  NODE  NODE  NODE  NODE  NODE  PXB   PXB   SYS   SYS   SYS   SYS   SYS   SYS   SYS   SYS   SYS   0-55,112-167    0              N/A
GPU4    NV18  NV18  NV18  NV18  X     NV18  NV18  NV18  SYS   SYS   SYS   SYS   SYS   SYS   SYS   SYS   SYS   PXB   PXB   NODE  NODE  NODE  NODE  NODE  NODE  NODE  56-111,168-223  1              N/A
GPU5    NV18  NV18  NV18  NV18  NV18  X     NV18  NV18  SYS   SYS   SYS   SYS   SYS   SYS   SYS   SYS   SYS   NODE  NODE  NODE  PXB   PXB   NODE  NODE  NODE  NODE  56-111,168-223  1              N/A
GPU6    NV18  NV18  NV18  NV18  NV18  NV18  X     NV18  SYS   SYS   SYS   SYS   SYS   SYS   SYS   SYS   SYS   NODE  NODE  NODE  NODE  NODE  PXB   PXB   NODE  NODE  56-111,168-223  1              N/A
GPU7    NV18  NV18  NV18  NV18  NV18  NV18  NV18  X     SYS   SYS   SYS   SYS   SYS   SYS   SYS   SYS   SYS   NODE  NODE  NODE  NODE  NODE  NODE  NODE  PXB   PXB   56-111,168-223  1              N/A
NIC0    PXB   NODE  NODE  NODE  SYS   SYS   SYS   SYS   X     PIX   NODE  NODE  NODE  NODE  NODE  NODE  NODE  SYS   SYS   SYS   SYS   SYS   SYS   SYS   SYS   SYS
NIC1    PXB   NODE  NODE  NODE  SYS   SYS   SYS   SYS   PIX   X     NODE  NODE  NODE  NODE  NODE  NODE  NODE  SYS   SYS   SYS   SYS   SYS   SYS   SYS   SYS   SYS
NIC2    NODE  NODE  NODE  NODE  SYS   SYS   SYS   SYS   NODE  NODE  X     NODE  NODE  NODE  NODE  NODE  NODE  SYS   SYS   SYS   SYS   SYS   SYS   SYS   SYS   SYS
NIC3    NODE  PXB   NODE  NODE  SYS   SYS   SYS   SYS   NODE  NODE  NODE  X     PIX   NODE  NODE  NODE  NODE  SYS   SYS   SYS   SYS   SYS   SYS   SYS   SYS   SYS
NIC4    NODE  PXB   NODE  NODE  SYS   SYS   SYS   SYS   NODE  NODE  NODE  PIX   X     NODE  NODE  NODE  NODE  SYS   SYS   SYS   SYS   SYS   SYS   SYS   SYS   SYS
NIC5    NODE  NODE  PXB   NODE  SYS   SYS   SYS   SYS   NODE  NODE  NODE  NODE  NODE  X     PIX   NODE  NODE  SYS   SYS   SYS   SYS   SYS   SYS   SYS   SYS   SYS
NIC6    NODE  NODE  PXB   NODE  SYS   SYS   SYS   SYS   NODE  NODE  NODE  NODE  NODE  PIX   X     NODE  NODE  SYS   SYS   SYS   SYS   SYS   SYS   SYS   SYS   SYS
NIC7    NODE  NODE  NODE  PXB   SYS   SYS   SYS   SYS   NODE  NODE  NODE  NODE  NODE  NODE  NODE  X     PIX   SYS   SYS   SYS   SYS   SYS   SYS   SYS   SYS   SYS
NIC8    NODE  NODE  NODE  PXB   SYS   SYS   SYS   SYS   NODE  NODE  NODE  NODE  NODE  NODE  NODE  PIX   X     SYS   SYS   SYS   SYS   SYS   SYS   SYS   SYS   SYS
NIC9    SYS   SYS   SYS   SYS   PXB   NODE  NODE  NODE  SYS   SYS   SYS   SYS   SYS   SYS   SYS   SYS   SYS   X     PIX   NODE  NODE  NODE  NODE  NODE  NODE  NODE
NIC10   SYS   SYS   SYS   SYS   PXB   NODE  NODE  NODE  SYS   SYS   SYS   SYS   SYS   SYS   SYS   SYS   SYS   PIX   X     NODE  NODE  NODE  NODE  NODE  NODE  NODE
NIC11   SYS   SYS   SYS   SYS   NODE  NODE  NODE  NODE  SYS   SYS   SYS   SYS   SYS   SYS   SYS   SYS   SYS   NODE  NODE  X     NODE  NODE  NODE  NODE  NODE  NODE
NIC12   SYS   SYS   SYS   SYS   NODE  PXB   NODE  NODE  SYS   SYS   SYS   SYS   SYS   SYS   SYS   SYS   SYS   NODE  NODE  NODE  X     PIX   NODE  NODE  NODE  NODE
NIC13   SYS   SYS   SYS   SYS   NODE  PXB   NODE  NODE  SYS   SYS   SYS   SYS   SYS   SYS   SYS   SYS   SYS   NODE  NODE  NODE  PIX   X     NODE  NODE  NODE  NODE
NIC14   SYS   SYS   SYS   SYS   NODE  NODE  PXB   NODE  SYS   SYS   SYS   SYS   SYS   SYS   SYS   SYS   SYS   NODE  NODE  NODE  NODE  NODE  X     PIX   NODE  NODE
NIC15   SYS   SYS   SYS   SYS   NODE  NODE  PXB   NODE  SYS   SYS   SYS   SYS   SYS   SYS   SYS   SYS   SYS   NODE  NODE  NODE  NODE  NODE  PIX   X     NODE  NODE
NIC16   SYS   SYS   SYS   SYS   NODE  NODE  NODE  PXB   SYS   SYS   SYS   SYS   SYS   SYS   SYS   SYS   SYS   NODE  NODE  NODE  NODE  NODE  NODE  NODE  X     PIX
NIC17   SYS   SYS   SYS   SYS   NODE  NODE  NODE  PXB   SYS   SYS   SYS   SYS   SYS   SYS   SYS   SYS   SYS   NODE  NODE  NODE  NODE  NODE  NODE  NODE  PIX   X

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

NIC Legend:

  NIC0: mlx5_0
  NIC1: mlx5_1
  NIC2: mlx5_2
  NIC3: mlx5_3
  NIC4: mlx5_4
  NIC5: mlx5_5
  NIC6: mlx5_6
  NIC7: mlx5_7
  NIC8: mlx5_8
  NIC9: mlx5_9
  NIC10: mlx5_10
  NIC11: mlx5_11
  NIC12: mlx5_12
  NIC13: mlx5_13
  NIC14: mlx5_14
  NIC15: mlx5_15
  NIC16: mlx5_16
  NIC17: mlx5_17


ulimit soft: 65535
