Conversation

@HaiShaw (Collaborator) commented Feb 4, 2025

Motivation

  1. Enable sgl-kernel on ROCm, make it easy to add individual kernels
  2. Use the sgl_moe_align_block_size kernel in fused_moe for performance

Modifications

As they are.

Checklist

```diff
-if is_cuda:
-    from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
+from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
```
Member:

Add is_cuda or is_hip here, we shouldn't disrupt other devices (such as xpu, hpu, etc.).

Collaborator Author:

Let's keep this one as-is for now (the code follows the previous pattern); I will add that in a follow-up, sounds good?

Member:

ok
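
A minimal sketch of the guard the reviewer is asking for, assuming hypothetical _is_cuda()/_is_hip() helpers (sglang has similar device-check utilities, but these names and the fallback are illustrative, not code from this PR):

```python
# Sketch: import the fused kernel only on devices that ship sgl-kernel,
# so other platforms (xpu, hpu, etc.) keep their existing code path.
import torch

def _is_cuda() -> bool:
    # Hypothetical helper: true on NVIDIA CUDA builds of PyTorch.
    return torch.cuda.is_available() and torch.version.cuda is not None

def _is_hip() -> bool:
    # Hypothetical helper: ROCm builds of PyTorch report torch.version.hip.
    return torch.cuda.is_available() and torch.version.hip is not None

if _is_cuda() or _is_hip():
    from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
else:
    sgl_moe_align_block_size = None  # caller falls back to the Triton kernel
```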


```python
sources = [
    "src/sgl-kernel/torch_extension_rocm.cc",
    "src/sgl-kernel/csrc/moe_align_kernel.cu",
```
Member:

Do you need to update this for ROCm?

```cpp
#define DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
  cudaFuncSetAttribute(FUNC, cudaFuncAttributeMaxDynamicSharedMemorySize, VAL)
```

Collaborator Author:

Not at this moment; it works as-is for now. We will do deeper customization later.

Member:

OK, I just remembered that I removed the ROCm-related code earlier.
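
For reference, a hedged sketch of what that deeper ROCm customization could look like. hipFuncSetAttribute and hipFuncAttributeMaxDynamicSharedMemorySize are real HIP runtime APIs, but this platform guard is an illustration, not the code merged in this PR:

```cpp
// Illustration only: dispatch the attribute macro per platform.
// __HIP_PLATFORM_AMD__ is defined by the HIP toolchain on ROCm builds.
#ifdef __HIP_PLATFORM_AMD__
  #include <hip/hip_runtime.h>
  #define DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL)  \
    hipFuncSetAttribute(reinterpret_cast<const void*>(FUNC),          \
                        hipFuncAttributeMaxDynamicSharedMemorySize, VAL)
#else
  #include <cuda_runtime.h>
  #define DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL)  \
    cudaFuncSetAttribute(FUNC, cudaFuncAttributeMaxDynamicSharedMemorySize, VAL)
#endif
```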

Collaborator Author:

@zhyncs created 2 issues to work on ASAP, thx.

@HaiShaw mentioned this pull request Feb 4, 2025
@zhyncs merged commit 2c1a695 into sgl-project:main on Feb 4, 2025 (18 of 19 checks passed)
timethink pushed a commit to timethink/sglang that referenced this pull request Mar 9, 2025