@BBuf BBuf commented Dec 9, 2024

  • add bf16 qwen2-57b-a14b tuning configs for tp2/tp4 on A800.
  • add a fused_moe_triton unittest covering bf16, fp16, and fp8_w8a8.
  • refine the fused_moe benchmark README.md.

When I want to deploy the qwen2-57b-a14b model on A800 with fp8, this error occurs:

[error screenshot]

The reason is that Triton currently does not support the fp8e4nv dtype on the Ampere architecture. To detect this situation early, I added the fused_moe test mentioned above, which verifies whether the fused_moe operator can run on the current GPU by checking its architecture (compute capability):


def _test_case(self, m, n, k, e, topk, dtype, use_fp8_w8a8=False):
    if use_fp8_w8a8:
        # AssertionError: fp8e4nv data type is not supported on CUDA arch < 89
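The capability check behind that early detection can be sketched as a small helper. This is a minimal illustration, not the actual test code from the PR; the helper name and the plain-tuple signature are assumptions, and on a real GPU the capability would come from `torch.cuda.get_device_capability()`.

```python
def fp8e4nv_supported(capability):
    """Return True if Triton's fp8e4nv dtype is usable on this GPU.

    fp8e4nv requires CUDA compute capability >= 8.9 (sm_89, i.e.
    Ada or Hopper); Ampere (8.0 / 8.6) raises the AssertionError
    shown above.

    `capability` is a (major, minor) tuple, matching the format
    returned by torch.cuda.get_device_capability().
    """
    return capability >= (8, 9)
```

A unittest could then call `self.skipTest(...)` (or assert the expected failure) when this returns False, instead of letting the Triton kernel crash.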

If the GPU is Ampere architecture, we should fall back from fused_moe to either the naive implementation or the torch.compile implementation to prevent errors.
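That fallback can be sketched as a tiny dispatcher; the function and return values here are hypothetical (the real routing in sglang is more involved), but the branch condition mirrors the sm_89 requirement above.

```python
def select_moe_impl(capability, use_fp8_w8a8):
    """Pick a MoE implementation for the given GPU.

    Hypothetical dispatcher: Triton's fp8e4nv needs compute
    capability >= 8.9, so for fp8 on Ampere we fall back to a
    naive / torch.compile path instead of fused_moe.
    """
    if use_fp8_w8a8 and capability < (8, 9):
        return "naive_or_torch_compile"
    return "fused_moe_triton"
```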

@merrymercy merrymercy merged commit 3844feb into sgl-project:main Dec 9, 2024
0 of 14 checks passed
timethink pushed a commit to timethink/sglang that referenced this pull request Mar 9, 2025