feat: support moe_align_block_size_triton #2712
Conversation
Co-authored-by: WANDY666 <1060304770@qq.com>
tokens_cnts,
num_experts,
numel,
tokens_per_thread,
Perhaps we can adjust num_warps and num_stages for better performance. num_stages=2 may work for these kernels.
Good suggestion. Kernel tuning will be done in a follow-up.
After setting num_warps=4 and num_stages=2, I ran an end-to-end test with bench_one_latency and saw little to no speedup. Once I have more bandwidth, I'll study how to improve the kernel in detail.
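For reference, these launch parameters are passed as meta-arguments at the kernel call site. The sketch below is only a generic illustration of that pattern (a simple per-expert histogram kernel, not the PR's actual kernel), assuming the suggested num_warps=4 and num_stages=2:

```python
# Minimal sketch (not the PR's kernel): how num_warps / num_stages are
# passed as launch meta-parameters to a Triton kernel.
import torch
import triton
import triton.language as tl


@triton.jit
def count_tokens_per_expert_kernel(
    topk_ids_ptr,      # int tensor of routed expert ids, flattened
    counts_ptr,        # int32 tensor of per-expert counts (pre-zeroed)
    numel,             # number of elements in topk_ids
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offs < numel
    expert = tl.load(topk_ids_ptr + offs, mask=mask, other=0)
    # Histogram the expert ids with atomics.
    tl.atomic_add(counts_ptr + expert, 1, mask=mask)


def count_tokens_per_expert(topk_ids: torch.Tensor, num_experts: int) -> torch.Tensor:
    counts = torch.zeros(num_experts, dtype=torch.int32, device=topk_ids.device)
    numel = topk_ids.numel()
    BLOCK_SIZE = 256
    grid = (triton.cdiv(numel, BLOCK_SIZE),)
    # Launch config suggested in the review: num_warps=4, num_stages=2.
    count_tokens_per_expert_kernel[grid](
        topk_ids, counts, numel,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=4,
        num_stages=2,
    )
    return counts
```

Since num_stages mainly affects software pipelining of loads inside loops, a small bookkeeping kernel like this has little to gain from it, which is consistent with the end-to-end result reported above.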
Online processing achieved a 20% speedup, but offline throughput decreased. As a result, this feature is disabled by default on NVIDIA GPUs and must be enabled via an environment variable.
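The variable name isn't quoted in this thread, so the following is only a hypothetical sketch of the opt-in gating pattern (the environment variable and function names are placeholders, not the actual SGLang setting):

```python
# Hypothetical sketch of opt-in gating; the variable name is a placeholder.
import os
from typing import Callable


def moe_align_triton_enabled() -> bool:
    # Off by default on NVIDIA GPUs; users opt in explicitly.
    return os.environ.get("ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0") == "1"


def pick_moe_align_impl(cuda_impl: Callable, triton_impl: Callable) -> Callable:
    # Keep the CUDA kernel as the default path; switch only when requested.
    return triton_impl if moe_align_triton_enabled() else cuda_impl
```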
Motivation
With batch sizes 1/8/32 and input/output lengths 128/256, the Triton kernel gives around a 20% improvement for online cases.
TODO: @BBuf will continue to optimize the CUDA version.
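For context, moe_align_block_size groups the routed token indices by expert and pads each expert's group up to a multiple of block_size, so every block handled by the fused MoE GEMM belongs to a single expert. Below is a naive PyTorch reference of that contract for checking the Triton kernel on small inputs (output names follow the usual fused-MoE convention; the exact tensors in sglang may differ):

```python
import torch


def moe_align_block_size_ref(topk_ids: torch.Tensor, block_size: int, num_experts: int):
    # Naive reference implementation; assumes num_experts >= 1.
    numel = topk_ids.numel()
    flat = topk_ids.flatten()

    sorted_token_ids = []
    expert_ids = []
    for e in range(num_experts):
        # Indices (into the flattened topk_ids) of tokens routed to expert e.
        tok = torch.nonzero(flat == e, as_tuple=True)[0].to(flat.dtype)
        # Pad each expert's group up to a multiple of block_size; `numel`
        # serves as the padding sentinel.
        padded_len = ((tok.numel() + block_size - 1) // block_size) * block_size
        pad = torch.full((padded_len - tok.numel(),), numel, dtype=flat.dtype)
        sorted_token_ids.append(torch.cat([tok, pad]))
        # One expert id per block of sorted token ids.
        expert_ids.extend([e] * (padded_len // block_size))

    sorted_token_ids = torch.cat(sorted_token_ids)
    expert_ids = torch.tensor(expert_ids, dtype=torch.int32)
    num_tokens_post_pad = torch.tensor([sorted_token_ids.numel()], dtype=torch.int32)
    return sorted_token_ids, expert_ids, num_tokens_post_pad
```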
Modifications
Checklist