
Conversation

@merrymercy (Contributor) commented on Nov 25, 2024:

This PR makes torch.compile with tensor parallelism much faster.

Llama 3.1 8B with tp = 8:

python3 -m sglang.bench_offline_throughput --model meta-llama/Llama-3.1-8B-Instruct --dataset-name random --num-prompt 1 --random-input 1024 --random-output 256 --random-range 1 --tp 8 --enable-torch-compile

before: 315.02 token/s
after: 457.22 token/s
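
For reference, enabling the same option when serving would presumably look like the following (a sketch assuming the standard sglang.launch_server flags --model-path, --tp, and --enable-torch-compile; not taken from this PR):

python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --tp 8 --enable-torch-compile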

merrymercy merged commit c4336b2 into main on Nov 25, 2024 (14 of 16 checks passed).
merrymercy deleted the pr-fix-torch-compile branch on November 25, 2024 at 22:55.
timethink pushed a commit to timethink/sglang that referenced this pull request on Mar 9, 2025.
@ZhuJiaqi9905 (Contributor) commented:

Since setting tp_group.ca_comm to None has been removed here, is it still necessary to back up ca_comm? Could we change patch_model() to the following:

@contextmanager
def patch_model(
    model: torch.nn.Module,
    enable_compile: bool,
    num_tokens: int,
):
    """Patch the model to make it compatible with with torch.compile"""
   

    try:
        if enable_compile:
            _to_torch(model, reverse=False, num_tokens=num_tokens)

            # Use custom allreduce here: we found that custom allreduce is much
            # faster than the built-in allreduce in torch, even with
            # ENABLE_INTRA_NODE_COMM=1.
            yield torch.compile(
                torch.no_grad()(model.forward),
                mode=os.environ.get(
                    "SGLANG_TORCH_COMPILE_MODE", "max-autotune-no-cudagraphs"
                ),
                dynamic=False,
            )
        else:
            yield model.forward
    finally:
        if enable_compile:
            _to_torch(model, reverse=True, num_tokens=num_tokens)
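
For contrast, this is roughly the backup-and-restore pattern the question refers to, as a minimal sketch (hypothetical helper name; tp_group and ca_comm mirror the snippet above, not the actual sglang source):

from contextlib import contextmanager

@contextmanager
def _ca_comm_disabled(tp_group):
    # Back up the custom-allreduce communicator and disable it so that
    # torch.compile traces the built-in torch allreduce instead, then
    # restore the communicator afterwards.
    backup_ca_comm = tp_group.ca_comm
    tp_group.ca_comm = None
    try:
        yield
    finally:
        tp_group.ca_comm = backup_ca_comm

If the communicator is never set to None inside patch_model, the backup variable has nothing left to restore, which is the point of the question.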
