@zhyncs zhyncs commented Jan 2, 2025

Motivation

# enable Triton implementation
export ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON=1

Benchmarked with batch sizes 1, 8, and 32 and input/output lengths of 128/256: around 20% improvement for online cases.

TODO: @BBuf will continue to optimize the CUDA version.

Prefill. latency: 2.03397 s, throughput:     62.93 token/s
Decode.  latency: 2.15392 s, throughput:      0.46 token/s
Decode.  latency: 0.02824 s, throughput:     35.41 token/s
Decode.  latency: 0.02772 s, throughput:     36.07 token/s
Decode.  latency: 0.02776 s, throughput:     36.02 token/s
Decode.  latency: 0.02775 s, throughput:     36.03 token/s
Decode.  median latency: 0.02776 s, median throughput:     36.02 token/s
Total. latency:  4.355 s, throughput:     31.23 token/s
Benchmark ...
Prefill. latency: 0.14870 s, throughput:    860.77 token/s
Decode.  latency: 0.02796 s, throughput:     35.76 token/s
Decode.  latency: 0.02771 s, throughput:     36.09 token/s
Decode.  latency: 0.02771 s, throughput:     36.09 token/s
Decode.  latency: 0.02771 s, throughput:     36.09 token/s
Decode.  latency: 0.02772 s, throughput:     36.07 token/s
Decode.  median latency: 0.02796 s, median throughput:     35.76 token/s
Total. latency:  7.268 s, throughput:     52.84 token/s

Prefill. latency: 5.71712 s, throughput:    179.11 token/s
Decode.  latency: 2.32011 s, throughput:      3.45 token/s
Decode.  latency: 0.03317 s, throughput:    241.20 token/s
Decode.  latency: 0.03296 s, throughput:    242.74 token/s
Decode.  latency: 0.03353 s, throughput:    238.57 token/s
Decode.  latency: 0.03384 s, throughput:    236.42 token/s
Decode.  median latency: 0.03384 s, median throughput:    236.42 token/s
Total. latency:  8.239 s, throughput:    132.06 token/s
Benchmark ...
Prefill. latency: 0.18112 s, throughput:   5653.82 token/s
Decode.  latency: 0.03253 s, throughput:    245.91 token/s
Decode.  latency: 0.03272 s, throughput:    244.47 token/s
Decode.  latency: 0.03296 s, throughput:    242.74 token/s
Decode.  latency: 0.03349 s, throughput:    238.89 token/s
Decode.  latency: 0.03379 s, throughput:    236.73 token/s
Decode.  median latency: 0.03415 s, median throughput:    234.25 token/s
Total. latency:  8.888 s, throughput:    345.62 token/s

Prefill. latency: 7.90844 s, throughput:    517.93 token/s
Decode.  latency: 1.95407 s, throughput:     16.38 token/s
Decode.  latency: 0.03557 s, throughput:    899.60 token/s
Decode.  latency: 0.03624 s, throughput:    882.95 token/s
Decode.  latency: 0.03713 s, throughput:    861.73 token/s
Decode.  latency: 0.03770 s, throughput:    848.71 token/s
Decode.  median latency: 0.03770 s, median throughput:    848.71 token/s
Total. latency: 10.087 s, throughput:    431.44 token/s
Benchmark ...
Prefill. latency: 0.30875 s, throughput:  13266.22 token/s
Decode.  latency: 0.03496 s, throughput:    915.27 token/s
Decode.  latency: 0.03531 s, throughput:    906.18 token/s
Decode.  latency: 0.03622 s, throughput:    883.53 token/s
Decode.  latency: 0.03698 s, throughput:    865.34 token/s
Decode.  latency: 0.03760 s, throughput:    851.03 token/s
Decode.  median latency: 0.04098 s, median throughput:    780.82 token/s
Total. latency: 10.729 s, throughput:   1145.35 token/s
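
For anyone reading the logs above, the throughput column appears to be batch size times tokens per sequence divided by step latency. Below is a minimal sketch of that arithmetic, assuming 128 input tokens per sequence for prefill and one new token per sequence per decode step. The helper name `throughput` is hypothetical and not part of sglang.

```python
def throughput(batch_size: int, tokens_per_seq: int, latency_s: float) -> float:
    """Tokens per second for a single step over the whole batch."""
    return batch_size * tokens_per_seq / latency_s


# Prefill, batch size 1, 128 input tokens, 2.03397 s (first log block).
prefill_bs1 = throughput(1, 128, 2.03397)

# Decode, batch size 1, one new token per step, 0.02776 s.
decode_bs1 = throughput(1, 1, 0.02776)

# Prefill, batch size 32, 128 input tokens each, 7.90844 s (third log block).
prefill_bs32 = throughput(32, 128, 7.90844)

print(f"{prefill_bs1:.2f} {decode_bs1:.2f} {prefill_bs32:.2f}")
```

These values agree with the logged 62.93, 36.02, and 517.93 token/s to within rounding, which is consistent with the three log groups corresponding to batch sizes 1, 8, and 32.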

Modifications

Checklist

  • Format your code according to the Contributor Guide.
  • Add unit tests as outlined in the Contributor Guide.
  • Update documentation as needed, including docstrings or example tutorials.

Co-authored-by: WANDY666 <1060304770@qq.com>
@zhyncs zhyncs requested a review from BBuf January 2, 2025 12:44
@zhyncs zhyncs marked this pull request as draft January 2, 2025 12:52
@zhyncs zhyncs marked this pull request as ready for review January 2, 2025 13:06
tokens_cnts,
num_experts,
numel,
tokens_per_thread,
Collaborator

Perhaps we can adjust num_warps and num_stages for better performance. num_stages=2 may work for these kernels.

Member Author

Good suggestion. Kernel tuning will be done in a follow-up.

Member Author

After setting num_warps=4 and num_stages=2, I ran an end-to-end test with bench_one_latency and saw little to no speedup. Once I have more bandwidth, I'll study how to improve the kernel in detail.

zhyncs commented Jan 2, 2025

ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON=1 python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code

python3 benchmark/gsm8k/bench_sglang.py --num-questions 2000 --parallel 2000 --num-shots 8
Accuracy: 0.951
Invalid: 0.000
Latency: 126.074 s
Output throughput: 1101.262 token/s

zhyncs commented Jan 2, 2025

Online processing achieved a 20% speedup, but offline throughput decreased. This feature is therefore disabled by default on NVIDIA GPUs and must be enabled by setting the environment variable ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON=1. Enabling it is recommended for latency-sensitive use cases. Kudos to @WANDY666.
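
For reference, an opt-in flag like this is typically read once from the environment and used to pick the code path. The following is a rough sketch of that pattern only, not the actual sglang code: the function name `triton_align_enabled` is hypothetical, and treating any value other than "1" as disabled is an assumption.

```python
import os
from typing import Mapping, Optional

FLAG = "ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON"


def triton_align_enabled(env: Optional[Mapping[str, str]] = None) -> bool:
    """Return True only when the flag is explicitly set to "1".

    Hypothetical helper: unset or any other value keeps the default path.
    """
    env = os.environ if env is None else env
    return env.get(FLAG, "0") == "1"


# Unset means the default (non-Triton) path is used.
default_choice = triton_align_enabled({})

# Explicit opt-in, e.g. for latency-sensitive serving.
opted_in = triton_align_enabled({FLAG: "1"})

print(default_choice, opted_in)
```

Passing the environment as a mapping keeps the check easy to exercise without mutating the real process environment.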

@zhyncs zhyncs merged commit ba5112f into main Jan 2, 2025
17 checks passed
@zhyncs zhyncs deleted the zhyncs/amd branch January 2, 2025 13:47
XiaotongJiang pushed a commit to XiaotongJiang/sglang that referenced this pull request Jan 3, 2025
Co-authored-by: WANDY666 <1060304770@qq.com>
timethink pushed a commit to timethink/sglang that referenced this pull request Mar 9, 2025
Co-authored-by: WANDY666 <1060304770@qq.com>