feat: support DeepSeek-R1-W4AFP8 model with ep-moe mode #7762
Conversation
Summary of Changes
Hello @yangsijia-serena, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request introduces support for the DeepSeek-R1-W4AFP8 model, focusing on its mixed-precision quantization for MoE layers. It integrates highly optimized CUTLASS kernels for efficient execution of int4 weights and fp8 activations, enhancing the system's capability to handle advanced quantized models with improved performance.
Highlights
- DeepSeek-R1-W4AFP8 Model Support: I've added comprehensive support for running the DeepSeek-R1-W4AFP8 model in ep-moe (Expert Parallelism Mixture of Experts) mode. This model utilizes a mixed-precision quantization scheme, with dense layers using FP8_BLOCK_SCALING and MoE layers employing INT4 weights and FP8 activation.
- CUTLASS W4A8 MoE Kernel Integration: A new cutlass_w4a8_moe kernel has been integrated, enabling efficient w4a8-quantized MoE layer computations. This kernel leverages CUTLASS grouped GEMM for optimized matrix multiplications, including pre- and post-reordering of inputs/outputs and handling SiLU and Mul activation functions.
- Quantization Method Extension: The system's quantization framework has been extended to recognize and support the new w4afp8 quantization method. This includes updates to model configuration parsing, allowing the system to automatically detect and apply the correct quantization method based on hf_quant_config.json.
- Expert Parallelism Enhancements: Modifications were made to the ep_moe layer to properly handle W4AFp8Config, including determining expert mapping across ranks and correctly loading input and weight scales for the mixed-precision MoE layers. New Triton kernels were added for efficient pre-processing of inputs for the CUTLASS MoE operation.
- Performance and Accuracy Benchmarks: Initial benchmarks demonstrate significant improvements in both throughput and latency when running DeepSeek-R1-W4AFP8 on 8*H20 with ep8, compared to DeepSeek-R1 on 16*H20 with ep16. Accuracy evaluations on MMLU, Math-500, and AIME2025 also show strong results, with further online A/B testing planned.
Code Review
This pull request adds support for the DeepSeek-R1-W4AFP8 model, which involves a new w4afp8
mixed-precision quantization method. The changes are extensive, including new CUDA kernels for MoE layers, modifications to model configuration, and updates to the quantization framework. The overall implementation looks solid, with new tests for the added functionality.
I've identified a few areas for improvement, mainly related to code clarity and maintainability, such as refactoring duplicated code and improving how different quantization configurations are handled. There is also an incomplete log message that should be fixed. Overall, great work on adding this complex feature.
self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
self.is_checkpoint_w4afp8_serialized = is_checkpoint_w4afp8_serialized
if is_checkpoint_w4afp8_serialized:
    logger.warning("Detected w4afp8 checkpoint. Please note that")
This warning message appears to be incomplete. Please complete the sentence to provide a clear and meaningful message to the user.
logger.warning("Detected w4afp8 checkpoint. Please note that")
logger.warning("Detected w4afp8 checkpoint. Please note that this is an experimental feature.")
def run_cutlass_moe_ep_preproess(local_topk_ids: torch.Tensor, local_num_experts: int):
    reorder_topk_ids, reorder_ids = torch.sort(local_topk_ids.view(-1), stable=True)

    seg_indptr = torch.zeros(
        local_num_experts + 1, device=local_topk_ids.device, dtype=torch.int64
    )
    src2dst = torch.empty(
        local_topk_ids.numel(), device=local_topk_ids.device, dtype=torch.int32
    )

    BLOCK_SIZE = 512
    grid = (triton.cdiv(local_topk_ids.numel(), BLOCK_SIZE),)
    compute_src2dst_triton_kernel[grid](
        reorder_ids, src2dst, local_topk_ids.numel(), BLOCK_SIZE
    )

    return reorder_topk_ids, src2dst, seg_indptr
The function run_cutlass_moe_ep_preproess is nearly identical to run_moe_ep_preproess. This duplication can lead to maintenance issues, where a bug fix or an enhancement in one might be missed in the other.
To improve maintainability, consider refactoring these two functions into a single, more generic function. The core logic is the same, and the different parameter names (local_topk_ids vs topk_ids, local_num_experts vs num_experts) can be handled by a single implementation.
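For illustration, a minimal sketch of the unified helper this suggestion describes, assuming the seg_indptr handling stays as in the cutlass variant above; the unified name is hypothetical, and the Triton kernel below is a simplified stand-in included only to keep the sketch self-contained:

import torch
import triton
import triton.language as tl

@triton.jit
def _compute_src2dst_kernel(reorder_ids_ptr, src2dst_ptr, num_toks, BLOCK_SIZE: tl.constexpr):
    # Simplified stand-in for compute_src2dst_triton_kernel:
    # invert the sort permutation so that src2dst[src_index] = dst_index.
    pid = tl.program_id(axis=0)
    dst_id = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = dst_id < num_toks
    src_id = tl.load(reorder_ids_ptr + dst_id, mask=mask)
    tl.store(src2dst_ptr + src_id, dst_id, mask=mask)

def run_moe_ep_preprocess_unified(topk_ids: torch.Tensor, num_experts: int):
    # Hypothetical unified version of run_moe_ep_preproess and
    # run_cutlass_moe_ep_preproess: same body, generalized parameter names.
    reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)

    seg_indptr = torch.zeros(num_experts + 1, device=topk_ids.device, dtype=torch.int64)
    src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int32)

    BLOCK_SIZE = 512
    grid = (triton.cdiv(topk_ids.numel(), BLOCK_SIZE),)
    _compute_src2dst_kernel[grid](reorder_ids, src2dst, topk_ids.numel(), BLOCK_SIZE)

    return reorder_topk_ids, src2dst, seg_indptr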
if self.use_w4afp8:
    local_topk_ids = topk_ids
    if self.expert_map is not None:
        "Translate info from expert_map to topk_ids"
Using a string literal as a comment is unconventional and can be confusing, as it might be mistaken for a multi-line docstring if it were triple-quoted. For better clarity and adherence to standard Python style, it's recommended to use a # for single-line comments.
"Translate info from expert_map to topk_ids" | |
# Translate info from expert_map to topk_ids |
if hasattr(self.quant_config, "activation_scheme"):
    assert self.quant_config.activation_scheme == "dynamic"
elif hasattr(self.quant_config, "linear_activation_scheme"):
    assert self.quant_config.linear_activation_scheme == "dynamic"
The use of hasattr to check for different attribute names (activation_scheme vs. linear_activation_scheme) across different config classes makes the code harder to read and maintain. This pattern is repeated multiple times in this file.
A cleaner approach would be to define a common property in the base QuantizationConfig class to abstract this detail. For example:
# In QuantizationConfig or a mixin
@property
def is_static_activation(self) -> bool:
    # Default implementation, can be overridden
    return getattr(self, 'activation_scheme', 'dynamic') == 'static'

# In W4AFp8Config
@property
def is_static_activation(self) -> bool:
    return self.linear_activation_scheme == 'static'
This would simplify the logic here to assert not self.quant_config.is_static_activation, improving clarity and making it easier to add new quantization configs in the future.
Great work~
Are there any bench results available for the Qwen3-moe series?
Are there any performance data comparisons on grouped gemm between w4af8 and deepgemm fp8?
@yangsijia-serena Please rebase. Thanks!
@yangsijia-serena Hi, congratulations on the great work. I'm trying to reproduce the profiling data on 8*H20, but my numbers are quite bad:

============ Serving Benchmark Result ============
Backend: sglang
Traffic request rate: 64.0
Max request concurrency: 64
Successful requests: 256
Benchmark duration (s): 1700.29
Total input tokens: 129311
Total generated tokens: 126170
Total generated tokens (retokenized): 125742
Request throughput (req/s): 0.15
Input token throughput (tok/s): 76.05
Output token throughput (tok/s): 74.20
Total token throughput (tok/s): 150.26
Concurrency: 5.16
----------------End-to-End Latency----------------
Mean E2E Latency (ms): 34242.03
Median E2E Latency (ms): 33717.12
---------------Time to First Token----------------
Mean TTFT (ms): 1096.96
Median TTFT (ms): 316.03
P99 TTFT (ms): 4741.74
---------------Inter-Token Latency----------------
Mean ITL (ms): 67.39
Median ITL (ms): 48.76
P95 ITL (ms): 236.50
P99 ITL (ms): 289.68
Max ITL (ms): 4043.97

Hope you can help analyze the reason. My environment info, launch script, and log are attached below. Thanks sincerely. Did I use the profiling script incorrectly? In addition, a strange phenomenon is that the running time is short and the GPU utilization is close to full; however, the stage of the

The environment infos are as follows:

Hardwares:
8*H20 (96GB)
Softwares:
cuda 12.4
sgl-kernel 0.2.1
sglang 0.4.8.post1 /cfs/xtchen/repositories/sglang/python

Launch and bench scripts:

# launch script
SGL_ENABLE_JIT_DEEPGEMM=1 python3 -m sglang.launch_server \
--model-path ${deepseek_r1_w4fp8_dir} \
--context-length 8192 \
--tp 8 \
--trust-remote-code \
--host 0.0.0.0 \
--port 8000 \
--mem-fraction-static 0.8 \
--enable-ep-moe \
--cuda-graph-max-bs 256 \
--cuda-graph-bs 1 2 4 8 16 32 64 128 256 \
--max-running-requests 256 \
--disable-radix-cache
# bench script
python3 -m sglang.bench_serving \
--backend sglang \
--base-url http://172.17.97.5:8000 \
--tokenizer /mnt/xtchen/model/DeepSeek-R1-W4AFP8 \
--model /mnt/xtchen/model/DeepSeek-R1-W4AFP8 \
--dataset-name random \
--dataset-path /cfs/xtchen/dataset/ShareGPT_V3_unfiltered_cleaned_split.json \
--random-input-len 1000 \
--random-output 1000 \
--num-prompts 256 \
--request-rate 64 \
--max-concurrency 64 \
--profile \
--output-file online.jsonl

The complete log of bench_serving is shown below:

benchmark_args=Namespace(backend='sglang', base_url='http://172.17.97.5:8000', host='0.0.0.0', port=None, dataset_name='random', dataset_path='/cfs/xtchen/dataset/ShareGPT_V3_unfiltered_cleaned_split.json', model='/mnt/xtchen/model/DeepSeek-R1-W4AFP8', tokenizer='/mnt/xtchen/model/DeepSeek-R1-W4AFP8', num_prompts=256, sharegpt_output_len=None, sharegpt_context_len=None, random_input_len=1000, random_output_len=1000, random_range_ratio=0.0, request_rate=64.0, max_concurrency=64, output_file='online.jsonl', output_details=False, disable_tqdm=False, disable_stream=False, return_logprob=False, seed=1, disable_ignore_eos=False, extra_request_body=None, apply_chat_template=False, profile=True, lora_name=None, prompt_suffix='', pd_separated=False, flush_cache=False, warmup_requests=1, tokenize_prompt=False, gsp_num_groups=64, gsp_prompts_per_group=16, gsp_system_prompt_len=2048, gsp_question_len=128, gsp_output_len=256)
Namespace(backend='sglang', base_url='http://172.17.97.5:8000', host='0.0.0.0', port=30000, dataset_name='random', dataset_path='/cfs/xtchen/dataset/ShareGPT_V3_unfiltered_cleaned_split.json', model='/mnt/xtchen/model/DeepSeek-R1-W4AFP8', tokenizer='/mnt/xtchen/model/DeepSeek-R1-W4AFP8', num_prompts=256, sharegpt_output_len=None, sharegpt_context_len=None, random_input_len=1000, random_output_len=1000, random_range_ratio=0.0, request_rate=64.0, max_concurrency=64, output_file='online.jsonl', output_details=False, disable_tqdm=False, disable_stream=False, return_logprob=False, seed=1, disable_ignore_eos=False, extra_request_body=None, apply_chat_template=False, profile=True, lora_name=None, prompt_suffix='', pd_separated=False, flush_cache=False, warmup_requests=1, tokenize_prompt=False, gsp_num_groups=64, gsp_prompts_per_group=16, gsp_system_prompt_len=2048, gsp_question_len=128, gsp_output_len=256)
#Input tokens: 129311
#Output tokens: 126170
Starting warmup with 1 sequences...
Warmup completed with 1 sequences. Starting main benchmark run...
Starting profiler...
Profiler started
100%|█████████████████████████████████████████████████████| 256/256 [02:34<00:00, 1.44it/s]Stopping profiler...
Profiler stopped
100%|█████████████████████████████████████████████████████| 256/256 [28:20<00:00, 6.64s/it]
============ Serving Benchmark Result ============
Backend: sglang
Traffic request rate: 64.0
Max request concurrency: 64
Successful requests: 256
Benchmark duration (s): 1700.29
Total input tokens: 129311
Total generated tokens: 126170
Total generated tokens (retokenized): 125742
Request throughput (req/s): 0.15
Input token throughput (tok/s): 76.05
Output token throughput (tok/s): 74.20
Total token throughput (tok/s): 150.26
Concurrency: 5.16
----------------End-to-End Latency----------------
Mean E2E Latency (ms): 34242.03
Median E2E Latency (ms): 33717.12
---------------Time to First Token----------------
Mean TTFT (ms): 1096.96
Median TTFT (ms): 316.03
P99 TTFT (ms): 4741.74
---------------Inter-Token Latency----------------
Mean ITL (ms): 67.39
Median ITL (ms): 48.76
P95 ITL (ms): 236.50
P99 ITL (ms): 289.68
Max ITL (ms): 4043.97
==================================================

Part of the server log during profiling is shown below:

[2025-07-06 12:25:30 TP0] Prefill batch. #new-seq: 1, #new-token: 39, #cached-token: 0, #token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0
[2025-07-06 12:25:31 TP0] Decode batch. #running-req: 1, #token: 0, token usage: 0.00, cuda graph: True, gen throughput (token/s): 1.38, #queue-req: 0
[2025-07-06 12:25:32 TP2] Profiling starts for True. Traces will be saved to: /tmp (with profile id: 1751804732.7859037)
[2025-07-06 12:25:32 TP3] Profiling starts for True. Traces will be saved to: /tmp (with profile id: 1751804732.7859037)
[2025-07-06 12:25:33] INFO: 172.17.97.5:46420 - "POST /start_profile HTTP/1.1" 200 OK
[2025-07-06 12:25:33] INFO: 172.17.97.5:46422 - "POST /generate HTTP/1.1" 200 OK
[2025-07-06 12:25:33 TP0] Prefill batch. #new-seq: 1, #new-token: 39, #cached-token: 0, #token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0
[2025-07-06 12:25:33] INFO: 172.17.97.5:46424 - "POST /generate HTTP/1.1" 200 OK
[2025-07-06 12:25:33] INFO: 172.17.97.5:46426 - "POST /generate HTTP/1.1" 200 OK
...
[2025-07-06 12:28:04 TP0] Decode batch. #running-req: 6, #token: 7554, token usage: 0.02, cuda graph: True, gen throughput (token/s): 249.90, #queue-req: 0
[2025-07-06 12:28:06 TP0] Decode batch. #running-req: 1, #token: 938, token usage: 0.00, cuda graph: True, gen throughput (token/s): 62.30, #queue-req: 0
[2025-07-06 12:28:07 TP6] Stop profiling...
[2025-07-06 12:28:07 TP0] Stop profiling...
[2025-07-06 12:53:53 TP2] Profiling done. Traces are saved to: /tmp
[2025-07-06 12:53:53 TP0] Profiling done. Traces are saved to: /tmp
[2025-07-06 12:53:53] INFO: 172.17.97.5:47284 - "POST /stop_profile HTTP/1.1" 200 OK
[2025-07-06 12:53:53] INFO: 172.17.97.5:50518 - "GET /get_server_info HTTP/1.1" 200 OK
[2025-07-06 12:56:39] INFO: 172.17.97.5:50856 - "POST /generate HTTP/1.1" 200 OK
Signed-off-by: yangsijia.614 <yangsijia.614@bytedance.com>
Force-pushed from 96d8889 to 12cd8e5
Signed-off-by: yangsijia.614 <yangsijia.614@bytedance.com>
Force-pushed from 12cd8e5 to f2c7328
Done. Thanks!
Hi, have you tried to run the benchmark without profiling? Profiling may affect the performance.
Signed-off-by: yangsijia.614 <yangsijia.614@bytedance.com>
@yangsijia-serena Thanks for the reminder. I got the result with the bench script below:

python3 -m sglang.bench_serving \
--backend sglang \
--base-url http://172.17.97.5:8000 \
--tokenizer /mnt/xtchen/model/DeepSeek-R1-W4AFP8 \
--model /mnt/xtchen/model/DeepSeek-R1-W4AFP8 \
--dataset-name random \
--dataset-path /cfs/xtchen/dataset/ShareGPT_V3_unfiltered_cleaned_split.json \
--random-range-ratio 1.0 \
--random-input-len 1000 \
--random-output 1000 \
--num-prompts 256 \
--request-rate 64 \
--max-concurrency 64
…eing set to None. Signed-off-by: yangsijia.614 <yangsijia.614@bytedance.com>
b_strides1,
c_strides1,
s_strides13,
128,
noticed that chunk_size is hard-coded to 128 here. wondering if only g128 is valid for w4fp8 in your test on hopper arch for now?
Actually, we have only tested w4afp8 on the DeepSeek-R1-W4AFP8 model so far, where the MoE weights are quantized with group_size=128. We can also implement dynamic passing of this value instead of hardcoding it, for future flexibility.
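For illustration, one way to avoid the hardcoded value would be to derive the group size from tensor shapes at the call site; the packing assumptions below (two int4 values per int8 along the input dimension, one scale per group) are guesses for this sketch, not the PR's confirmed layout:

import torch

def infer_moe_group_size(qweight: torch.Tensor, w_scale: torch.Tensor) -> int:
    # Hypothetical helper: derive the quantization group size from shapes
    # instead of hardcoding 128 at the kernel call site.
    k = qweight.shape[-1] * 2        # unpacked input dim, assuming 2 int4 per int8
    num_groups = w_scale.shape[-1]   # assuming one scale per (row, group)
    assert k % num_groups == 0
    return k // num_groups

# Example with a DeepSeek-R1-W4AFP8-like layout (group_size = 128):
qweight = torch.empty(256, 7168 // 2, dtype=torch.int8)
w_scale = torch.empty(256, 7168 // 128)
print(infer_moe_group_size(qweight, w_scale))  # -> 128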
quant_method = cls.get_from_keys(config, ["quant_method"])
is_checkpoint_fp8_serialized = "fp8" in quant_method
is_checkpoint_w4afp8_serialized = "w4afp8" in quant_method
linear_activation_scheme = "dynamic"
Just wondering if the quantization scheme is limited for now?
The same reason as another comment: the limitation of this quantization scheme is consistent with the DeepSeek-R1-W4AFP8 model. To support other w4a8 models, we may need additional modifications. For instance, the linear layer computation uses the Fp8LinearMethod directly because the quantization for DeepSeek-R1-W4AFP8's linear layer aligns with that of DeepSeek-R1. If we encounter another model with a w4a8 moe layer but a different linear layer quantization, further adjustments will be necessary to accommodate it.
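To make that split concrete, here is a minimal, self-contained sketch of the per-layer dispatch described above. The class names echo the components in this PR, but the bodies are stubs and the layer-type checks are illustrative rather than the actual sglang code:

class Fp8LinearMethod:
    # Stub: dense layers reuse the existing DeepSeek-R1 fp8 block-scaling path.
    pass

class W4AFp8MoEMethod:
    # Stub: int4-weight / fp8-activation path for the MoE experts.
    def __init__(self, cfg):
        self.cfg = cfg

class FusedMoE:        # stand-in for the framework's MoE layer type
    pass

class LinearBase:      # stand-in for the framework's linear layer type
    pass

class W4AFp8ConfigSketch:
    def get_quant_method(self, layer, prefix: str):
        # MoE experts get the mixed-precision w4a8 method; dense layers fall
        # back to the fp8 linear method, as described in the reply above.
        if isinstance(layer, FusedMoE):
            return W4AFp8MoEMethod(self)
        if isinstance(layer, LinearBase):
            return Fp8LinearMethod()
        return None

# Usage sketch:
cfg = W4AFp8ConfigSketch()
print(type(cfg.get_quant_method(FusedMoE(), prefix="mlp.experts")).__name__)          # W4AFp8MoEMethod
print(type(cfg.get_quant_method(LinearBase(), prefix="self_attn.o_proj")).__name__)   # Fp8LinearMethod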
@yangsijia-serena @zhyncs Could you add documentation providing usage guidance for DeepSeek-R1-W4AFP8?
Hi, thanks for the great work. I noticed there is only R1-W4AFP8 on Hugging Face (not V3-0324 or R1-0528). Could you please share guidance on how to quantize an arbitrary model into that format (e.g., did you use TensorRT-Model-Optimizer)?
ok, will do~
You can refer to https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/models/core/deepseek_v3/README.md#w4afp8
Thanks for the information!
Excellent work! However, when we tried to reproduce the results, we encountered some accuracy issues; our results are quite different from those shown in the PR.
Hi @pengyao96, we previously tested the accuracy using evalscope.
…7762) Signed-off-by: yangsijia.614 <yangsijia.614@bytedance.com>
…7762) Signed-off-by: yangsijia.614 <yangsijia.614@bytedance.com>
Motivation
This PR supports running the DeepSeek-R1-W4AFP8 model with ep-moe mode (deepep mode support is on the way~).
Due to the reduced space required for model weights and decreased bandwidth usage, DeepSeek-R1 can now be run on a single H20 or H100 node, leading to improved throughput and latency.
Usage:
Run without mtp:
Run with mtp:
Note: DRAFT_MODEL can be exported using the export_deepseek_nextn.py script.
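The exact usage commands are not reproduced above. As a rough sketch, a non-mtp launch reusing the flags from the 8*H20 benchmark reported later in this thread might look like this (model path and port are placeholders; the mtp variant additionally requires the exported DRAFT_MODEL):

# Hedged launch sketch (without mtp), based on the benchmark launch script in this thread
python3 -m sglang.launch_server \
  --model-path /path/to/DeepSeek-R1-W4AFP8 \
  --tp 8 \
  --enable-ep-moe \
  --trust-remote-code \
  --host 0.0.0.0 \
  --port 30000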
Benchmark
Performance:
We ran DeepSeek-R1-W4AFP8 on 8*H20 with ep8 and compared it against DeepSeek-R1 on 16*H20 with ep16.
Test configuration: input/output len = 1000/1000, qps=64, max_concurrency=64, num_prompt=256.
The results are shown below:
DeepSeek-R1-W4AFP8 on 8*H20 with ep8
DeepSeek-R1 on 16*H20 with ep16
We can see a clear improvement in both throughput and latency when using DeepSeek-R1-W4AFP8.
Accuracy:
We evaluated the model accuracy on some typical benchmarks; the results are:
MMLU: 90.82
Math-500: 94.6
AIME2025: 66.7
We will do online A/B test for further verification.
Model Info
DeepSeek-R1-W4AFP8 is a mixed-precision quantized version of DeepSeek-R1, with dense layers using FP8_BLOCK_SCALING and MoE layers using INT4 weights with FP8 activations.
Modifications
Architecture Overview
Key Components
W4AFp8Config
Defines the quantization methods for the model, encompassing both weight and activation quantization. During inference, W4AFp8Config is responsible for selecting the appropriate QuantizationMethods based on the layer type (e.g., Linear, Attention, MoE, etc.).
W4AFp8MoEMethod
Encapsulates the quantization logic for the W4AFp8 MoE layer. It mainly includes two core methods:
- create_weights: Defines and initializes the weight and scale parameters for the W4AFp8 MoE layer. During model loading, the safetensors are parsed based on these parameters.
- process_weights_after_loading: Performs post-processing on the weights after the model weights and scales have been loaded into the layer object, converting them into the structure and types required for inference.
Cutlass W4A8 MoE
The concrete implementation of the W4AFp8 MoE computation: it applies scale operations to the input hidden states and intermediate values, and invokes triton and cutlass kernels to complete the w4a8 grouped GEMM operations, among other steps.
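A high-level sketch of that compute flow (pre-reorder, grouped GEMM, SiLU and Mul, grouped GEMM, post-reorder) is shown below. Everything in it is an illustrative stand-in: the per-expert loop and expert_fn replace the real triton/cutlass kernels, and the fp8/int4 scaling details are omitted:

import torch

def cutlass_w4a8_moe_flow_sketch(hidden_states, topk_ids, topk_weights,
                                 num_local_experts, expert_fn):
    # 1. Pre-reorder: sort token->expert assignments so each expert's tokens are contiguous.
    reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)
    gathered = hidden_states[reorder_ids // topk_ids.shape[1]]

    # 2-3. Stand-in for the two w4a8 grouped GEMMs with SiLU*Mul in between:
    # each expert processes its contiguous slice of tokens.
    out = torch.zeros_like(gathered)
    for e in range(num_local_experts):
        mask = reorder_topk_ids == e
        if mask.any():
            out[mask] = expert_fn(e, gathered[mask])

    # 4. Post-reorder: scatter results back to the original order and apply router weights.
    scattered = torch.zeros_like(out)
    scattered[reorder_ids] = out
    scattered = scattered * topk_weights.view(-1, 1)

    # 5. Sum the top-k expert outputs per token.
    return scattered.view(*topk_ids.shape, -1).sum(dim=1)

# Usage sketch with identity "experts":
h = torch.randn(4, 8)
ids = torch.randint(0, 2, (4, 2))
w = torch.full((4, 2), 0.5)
print(cutlass_w4a8_moe_flow_sketch(h, ids, w, 2, lambda e, x: x).shape)  # torch.Size([4, 8])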
Kernel
Completes the kernel portion of the MoE computation. Both triton kernels and cutlass kernels are used, with their usages and selection reasons as follows:
Workflow
The workflow for running DeepSeek-R1-W4AFP8 can be outlined as follows:
Checklist