

@ch-wan ch-wan commented Feb 19, 2025

Motivation

This PR partially addresses #3633.

Modifications

We reuse the memory of intermediate_cache1 when creating intermediate_cache3, avoiding a separate allocation for the third buffer.
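Conceptually, the trick is to view intermediate_cache3 over the storage already allocated for intermediate_cache1, which is no longer read once the activation output has been written to intermediate_cache2. Here is a minimal sketch of that aliasing pattern (illustrative shapes on CPU, not the exact PR code):

```python
import torch

# Illustrative dimensions: M tokens, topk experts per token,
# I = intermediate size, H = hidden size (so 2 * I >= H must hold
# for cache3 to fit inside cache1's storage).
M, topk, I, H = 16, 2, 8192, 4096

# cache1 receives the first grouped GEMM output; after the activation
# step it is dead, so cache3 can alias its storage instead of
# allocating a fresh (M, topk, H) buffer.
cache1 = torch.empty((M, topk, 2 * I), dtype=torch.float16)
cache3 = cache1.flatten()[: M * topk * H].view(M, topk, H)

# Both tensors share one allocation: no extra memory is used.
assert cache3.data_ptr() == cache1.data_ptr()
```

Writing into cache3 overwrites the front of cache1, which is safe precisely because cache1's contents are not needed at that point in the kernel's dataflow.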

Here is the test script:

```python
import torch
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe

N = 64 * 1024
E = 8
H = 4096
I = 8192

torch.manual_seed(0)

x = torch.randn((N, H), device="cuda", dtype=torch.float16) / 32
w1 = torch.randn((E, I * 2, H), device="cuda", dtype=torch.float16) / 32
w2 = torch.randn((E, H, I), device="cuda", dtype=torch.float16) / 32

gating_output = torch.randn((N, E), device="cuda", dtype=torch.float16)
topk = 2

x = fused_moe(x, w1, w2, gating_output, topk, True)

print(x)
print(torch.cuda.max_memory_allocated() // 1024 // 1024, "MB")
```

The output of the original implementation:

```
tensor([[-2.9869e-03,  3.7422e-03, -2.4395e-03,  ..., -2.0447e-03,
          8.6212e-03,  3.5362e-03],
        [-9.5520e-03,  6.5231e-03, -5.9586e-03,  ...,  1.5235e-04,
         -4.0359e-03,  5.0354e-03],
        [-5.5618e-03,  1.4296e-03, -6.3705e-03,  ..., -3.5400e-03,
         -4.6921e-03,  1.0918e-02],
        ...,
        [ 1.5354e-03,  7.7057e-03,  3.3035e-03,  ..., -1.1559e-03,
         -4.1962e-03, -1.9894e-03],
        [-9.8801e-03, -4.3716e-03,  8.8358e-04,  ...,  8.3847e-03,
         -8.6594e-04,  1.0101e-02],
        [-2.0733e-03,  9.3555e-04, -9.3162e-05,  ..., -1.1826e-03,
         -3.6907e-03, -4.7035e-03]], device='cuda:0', dtype=torch.float16)
9730 MB
```

This PR reduces peak memory by 10.5% (9730 MB → 8706 MB) while producing identical output:

```
tensor([[-2.9869e-03,  3.7422e-03, -2.4395e-03,  ..., -2.0447e-03,
          8.6212e-03,  3.5362e-03],
        [-9.5520e-03,  6.5231e-03, -5.9586e-03,  ...,  1.5235e-04,
         -4.0359e-03,  5.0354e-03],
        [-5.5618e-03,  1.4296e-03, -6.3705e-03,  ..., -3.5400e-03,
         -4.6921e-03,  1.0918e-02],
        ...,
        [ 1.5354e-03,  7.7057e-03,  3.3035e-03,  ..., -1.1559e-03,
         -4.1962e-03, -1.9894e-03],
        [-9.8801e-03, -4.3716e-03,  8.8358e-04,  ...,  8.3847e-03,
         -8.6594e-04,  1.0101e-02],
        [-2.0733e-03,  9.3555e-04, -9.3162e-05,  ..., -1.1826e-03,
         -3.6907e-03, -4.7035e-03]], device='cuda:0', dtype=torch.float16)
8706 MB
```
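The 1024 MB difference (9730 − 8706) is consistent with the size of intermediate_cache3 in this test, which no longer needs its own allocation. A quick sanity check, assuming intermediate_cache3 holds N × topk token rows of hidden size H in float16:

```python
# Dimensions from the test script above.
N, topk, H = 64 * 1024, 2, 4096

# float16 is 2 bytes per element; (N, topk, H) elements in total.
saved_mb = N * topk * H * 2 // 1024 // 1024
print(saved_mb, "MB")  # 1024 MB, exactly the observed reduction
```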


@zhyncs zhyncs merged commit 6b0aeb5 into sgl-project:main Feb 19, 2025
1 of 4 checks passed
zhyncs commented Feb 19, 2025

Thanks!!
