
Conversation

pavanimajety
Collaborator

@pavanimajety pavanimajety commented Jun 4, 2025

Motivation

Adds model support for the DeepSeek R1 FP4 model (functional enablement; kernel optimizations are underway).

Modifications

Adds ModelOptNvFp4FusedMoEMethod and a CutlassMoEParams dataclass to initialize the parameters required by the Cutlass MoE methods.

Checklist

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Hello @pavanimajety, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

Summary of Changes

Gemini here, providing a summary for this pull request. This PR introduces support for the DeepSeek R1 FP4 model. The core changes involve implementing the FP4 Mixture-of-Experts (MoE) layer using Cutlass kernels. To facilitate this, a new dataclass CutlassMoEParams has been added to encapsulate the various parameters required by the Cutlass MoE operations. Additionally, a new quantization method, ModelOptNvFp4FusedMoEMethod, is introduced to handle the specific weight loading, processing, and application logic for these FP4 MoE layers within the model optimization framework.

Highlights

  • Cutlass MoE Parameters Dataclass: Introduces CutlassMoEParams to consolidate and manage parameters (strides, problem sizes, offsets, etc.) for Cutlass-based MoE operations, simplifying function signatures.
  • FP4 Fused MoE Quantization Method: Adds ModelOptNvFp4FusedMoEMethod to handle the specific requirements of FP4 quantized MoE layers, including weight processing and integration with the Cutlass kernel.
  • Refactored Cutlass MoE Function: The cutlass_moe_fp4 function is updated to accept the new CutlassMoEParams object, making its interface cleaner (see the sketch below).
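
A rough sketch of what the bundled interface could look like is below; the enum values, field names, and shapes are illustrative assumptions, not the exact definitions in cutlass_moe_params.py:

from dataclasses import dataclass, field
from enum import Enum

import torch


class CutlassMoEType(Enum):
    BlockscaledFP4 = "fp4"
    BlockscaledFP8 = "fp8"


@dataclass
class CutlassMoEParams:
    """Bundles the metadata the Cutlass grouped GEMMs need, so cutlass_moe_fp4
    no longer takes a long list of loose tensor arguments."""

    moe_type: CutlassMoEType
    device: torch.device
    num_experts: int
    hidden_size: int
    intermediate_size_per_partition: int
    # Workspace tensors allocated once and reused across forward calls.
    expert_offsets: torch.Tensor = field(init=False)
    problem_sizes1: torch.Tensor = field(init=False)
    problem_sizes2: torch.Tensor = field(init=False)

    def __post_init__(self):
        e = self.num_experts
        self.expert_offsets = torch.empty(e + 1, dtype=torch.int32, device=self.device)
        self.problem_sizes1 = torch.empty(e, 3, dtype=torch.int32, device=self.device)
        self.problem_sizes2 = torch.empty(e, 3, dtype=torch.int32, device=self.device)

    def to_gemm1_args(self) -> dict:
        # Arguments for the first (gate_up_proj) grouped GEMM.
        return {"expert_offsets": self.expert_offsets, "problem_sizes": self.problem_sizes1}

    def to_gemm2_args(self) -> dict:
        # Arguments for the second (down_proj) grouped GEMM.
        return {"expert_offsets": self.expert_offsets, "problem_sizes": self.problem_sizes2}

With a bundle like this, cutlass_moe_fp4 can take a single params argument and derive the grouped-GEMM metadata from it, which is what makes the refactored signature cleaner.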

Changelog

  • python/sglang/srt/layers/moe/cutlass_moe.py
    • Imported CutlassMoEParams.
    • Renamed cutlass_fused_experts to cutlass_fused_experts_fp8 to clarify its FP8 specific nature.
    • Modified the cutlass_moe_fp4 function signature to accept a CutlassMoEParams object.
    • Updated the implementation of cutlass_moe_fp4 to use parameters from the CutlassMoEParams object.
    • Adjusted assertions and the final output calculation in cutlass_moe_fp4.
  • python/sglang/srt/layers/moe/cutlass_moe_params.py
    • Added a new file defining CutlassMoEType enum and CutlassMoEParams dataclass.
    • CutlassMoEParams stores various parameters needed for Cutlass MoE operations.
    • Includes methods to_gemm1_args and to_gemm2_args to prepare arguments for the two GEMM stages.
  • python/sglang/srt/layers/quantization/fp8.py
    • Updated import and function call to use the renamed cutlass_fused_experts_fp8.
  • python/sglang/srt/layers/quantization/modelopt_quant.py
    • Imported necessary classes and functions related to FP4 MoE and layer exclusion.
    • Added a check and warning for missing quantization config values.
    • Introduced is_layer_excluded method for regex-based module exclusion.
    • Modified get_quant_method to use is_layer_excluded and return UnquantizedLinearMethod for excluded layers.
    • Added the ModelOptNvFp4FusedMoEMethod class to handle FP4 MoE quantization.
    • Implemented create_weights, swizzle_blockscale, process_weights_after_loading, and apply methods within ModelOptNvFp4FusedMoEMethod.
  • python/sglang/srt/models/deepseek_v2.py
    • Updated a log message to indicate FP4 support for shared experts fusion optimization.
  • python/sglang/test/test_cutlass_moe.py
    • Updated imports and function calls to use the renamed cutlass_fused_experts_fp8.
  • sgl-kernel/python/sgl_kernel/moe.py
    • Imported Dict and Any.
    • Modified the cutlass_fp4_group_mm function signature to accept a single params dictionary.
    • Updated the function body to access parameters from the params dictionary.
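
For intuition on the "single params dictionary" change to cutlass_fp4_group_mm, the toy pure-PyTorch stand-in below drives a grouped matmul from an expert_offsets entry in such a dict; the key name and the Python loop are illustrative only (the real kernel is a Cutlass grouped GEMM and also consumes FP4 block scales):

from typing import Any, Dict

import torch


def grouped_mm_from_params(a: torch.Tensor, b: torch.Tensor,
                           params: Dict[str, Any]) -> torch.Tensor:
    """a: (total_tokens, k) activations already sorted by expert,
    b: (num_experts, k, n) per-expert weights."""
    expert_offsets = params["expert_offsets"]  # (num_experts + 1,) int tensor
    out = a.new_empty(a.shape[0], b.shape[-1])
    for e in range(b.shape[0]):
        start, end = int(expert_offsets[e]), int(expert_offsets[e + 1])
        out[start:end] = a[start:end] @ b[e]
    return out


# Tiny example: 2 experts, tokens 0-1 routed to expert 0, token 2 to expert 1.
a = torch.randn(3, 8)
b = torch.randn(2, 8, 16)
print(grouped_mm_from_params(a, b, {"expert_offsets": torch.tensor([0, 2, 3])}).shape)  # torch.Size([3, 16])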

Four bits for speed,
Experts spread wide,
Cutlass kernels gleam,
MoE's swift tide.


Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces support for FP4 quantization for Mixture of Experts (MoE) layers, specifically targeting nvidia/DeepSeekR1-FP4 model compatibility. The changes include the addition of CutlassMoEParams to encapsulate MoE parameters and ModelOptNvFp4FusedMoEMethod to handle the FP4 quantization logic for FusedMoE layers. The refactoring of parameters is a good improvement for clarity and maintainability. The core logic for FP4 MoE seems well-structured. However, there are a few points that need attention, particularly a removed assertion that could impact correctness.

Summary of Findings

  • Missing Newline at End of File: File python/sglang/srt/layers/moe/cutlass_moe_params.py is missing a newline character at the end of the file. This is a minor PEP 8 style issue.
  • Commented-out Parameters in CutlassMoEParams: In python/sglang/srt/layers/moe/cutlass_moe_params.py, the methods to_gemm1_args and to_gemm2_args have commented-out pointer parameters (e.g., a_ptrs, b_ptrs). If these are not currently used by cutlass_fp4_group_mm or other FP4 paths, they should be removed for clarity or their purpose (e.g., future use for FP8) should be documented.
  • Unused Import: In python/sglang/srt/layers/quantization/modelopt_quant.py, within the ModelOptNvFp4FusedMoEMethod.apply method (line 722), fused_experts is imported from sglang.srt.layers.moe.fused_moe_triton.fused_moe but does not appear to be used in this method. This import could be removed.
  • Local Import of regex: In python/sglang/srt/layers/quantization/modelopt_quant.py, the regex module is imported inside the is_layer_excluded method (line 300). PEP 8 generally recommends imports be at the top of the file. If regex is a standard dependency, consider moving the import. If it's optional or specific to this function to avoid circular dependencies or for other reasons, a comment explaining this might be helpful.
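
As a rough illustration of the regex-based exclusion described above (the standard re module is used here for simplicity, whereas the PR imports the third-party regex package; the pattern list is an assumed config attribute):

import re
from typing import Iterable


def is_layer_excluded(prefix: str, exclude_patterns: Iterable[str]) -> bool:
    # Modules whose dotted path matches an exclusion pattern fall back to
    # UnquantizedLinearMethod instead of the FP4 quantized method.
    return any(re.search(pattern, prefix) for pattern in exclude_patterns)


print(is_layer_excluded("model.layers.0.mlp.gate", [r"\.mlp\.gate$"]))          # True
print(is_layer_excluded("model.layers.0.self_attn.q_proj", [r"\.mlp\.gate$"]))  # False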

Merge Readiness

This pull request makes significant strides in adding FP4 MoE support. The code is generally well-structured. However, there is a high-severity concern regarding a removed assertion in cutlass_moe.py that needs to be addressed to ensure correctness. Additionally, a medium-severity issue regarding device handling in swizzle_blockscale should be reviewed. I recommend addressing these points before merging. I am not authorized to approve pull requests, so please ensure further review and approval from other maintainers.

padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32,
                                    cols // 4, 4)
swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5))
swizzled_scale = swizzled_scale.contiguous().cuda()
Contributor

medium

In the swizzle_blockscale method, swizzled_scale is explicitly moved to CUDA using .cuda().

Is this explicit move necessary? Typically, tensors are managed on the device specified by the layer or model (e.g., self.quant_config.device). If scale (the input to this method) is already on the correct device, .cuda() might be redundant. If the device could be something other than CUDA, this would cause an error.

Consider using .to(self.quant_config.device) or ensuring scale is already on the intended device to make the device handling more robust and explicit. For example:

# Assuming self.quant_config.device holds the target device
swizzled_scale = swizzled_scale.contiguous().to(self.quant_config.device)
Suggested change:

        swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5))
        swizzled_scale = swizzled_scale.contiguous().to(scale.device)  # or self.quant_config.device if appropriate
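
For reference, a minimal device-preserving version of the swizzle along these lines might look as follows; padding to multiples of 128 rows and 4 columns is assumed to have happened already, and the exact tiled layout should be checked against the PR:

import torch


def swizzle_blockscale(scale: torch.Tensor) -> torch.Tensor:
    # scale: (batches, rows, cols) block scales with rows % 128 == 0 and cols % 4 == 0.
    batches, rows, cols = scale.shape
    padded_scale = scale.reshape(batches, rows // 128, 4, 32, cols // 4, 4)
    swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5))
    # Stay on whatever device the input scales already live on instead of forcing .cuda().
    return swizzled_scale.contiguous().to(scale.device)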

Signed-off-by: Pavani Majety <pmajety@nvidia.com>
@pavanimajety
Collaborator Author

Please note that this PR functionally enables FP4 checkpoints. The optimizations for cutlass_fp4_moe through Flashinfer are underway and will add more configs for better end-to-end performance of the nvidia/DeepSeekR1-FP4 model.

@pavanimajety
Collaborator Author

Model test:

 python -m sglang.launch_server --model-path nvidia/DeepSeek-R1-FP4 --tp 4 --quantization modelopt_fp4 --kv-cache-dtype=auto

Result:

root@gb-nvl-054-compute01:/workspace/scratch-pmaj-1/gh-pm-sglang# curl http://127.0.0.1:30000/v1/completions   -H "Content-Type: application/json"   -d '{
    "model": "default",
    "prompt": "The President of United States is known to be ",
    "temperature": 0
  }'
{"id":"836ab52ebd8e45b19faf95d2cc6e3855","object":"text_completion","created":1749095546,"model":"default","choices":[{"index":0,"text":"1 of the most powerful person in the world. The President is the head of","logprobs":null,"finish_reason":"length","matched_stop":null}],"usage":{"prompt_tokens":11,"total_tokens":27,"completion_tokens":16,"prompt_tokens_details":null}}
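
The same smoke test can be run from Python against the OpenAI-compatible endpoint shown above (assuming the server from the launch command is listening on 127.0.0.1:30000):

import requests

resp = requests.post(
    "http://127.0.0.1:30000/v1/completions",
    json={
        "model": "default",
        "prompt": "The President of United States is known to be ",
        "temperature": 0,
    },
)
print(resp.json()["choices"][0]["text"])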

@pavanimajety pavanimajety marked this pull request as ready for review June 5, 2025 03:53
@pavanimajety pavanimajety changed the title [Draft][DeepseekR1-FP4] Add Support for nvidia/DeepSeekR1-FP4 model [DeepseekR1-FP4] Add Support for nvidia/DeepSeekR1-FP4 model Jun 5, 2025
@pavanimajety
Collaborator Author

#6887 needs to be merged for the CI to look clean

@zhyncs zhyncs self-assigned this Jun 5, 2025
@pavanimajety
Collaborator Author

@zhyncs Checking again - all the failures seem unrelated. Is there anything else needed?

        return ModelOptFp4LinearMethod(self)
    if self.kv_cache_quant_algo and isinstance(layer, RadixAttention):
        return ModelOptFp8KVCacheMethod(self)

    elif isinstance(layer, FusedMoE):
        return ModelOptNvFp4FusedMoEMethod(self)
Collaborator

I think we can make it more robust to cover FP8 fused moe as well.

@pavanimajety
Collaborator Author

GSM-8K evaluation: 0.9530 ± 0.0058

lm_eval --model sglang \
    --model_args pretrained=nvidia/DeepSeek-R1-0528-FP4,tp_size=4,kv_cache_dtype="fp8_e4m3",max_model_len=32768,add_bos_token=True,quantization="modelopt_fp4" \
    --tasks gsm8k \
    --batch_size 256 --num_fewshot 5

Member

@zhyncs zhyncs left a comment

unblock first

@zhyncs zhyncs merged commit c2c4f57 into sgl-project:main Jun 8, 2025
1 of 42 checks passed
jianan-gu pushed a commit to jianan-gu/sglang that referenced this pull request Jun 12, 2025
@pyc96 pyc96 mentioned this pull request Jun 21, 2025
walker-ai pushed a commit to walker-ai/sglang that referenced this pull request Jul 8, 2025
Merge branch 'sgl_20250610_sync_tag047 of git@code.alipay.com:Theta/SGLang.git into main

https://code.alipay.com/Theta/SGLang/pull_requests/52


Reviewed-by: 剑川 <jianchuan.gys@antgroup.com>


* [Bugfix] Fix slice operation when chunk size mismatch (sgl-project#6697)
* [Bugfix] Fix ChatCompletion endpoint of mini_lb when stream is set (sgl-project#6703)
* [CI] Fix setup of disaggregation with different tp (sgl-project#6706)
* [PD] Remove Unnecessary Exception Handling for FastQueue.get() (sgl-project#6712)
* Fuse routed_scaling_factor in DeepSeek (sgl-project#6710)
* Overlap two kernels in DeepSeek with communication (sgl-project#6711)
* Minor refactor two-batch overlap (sgl-project#6682)
* Speed up when having padding tokens two-batch overlap (sgl-project#6668)
* [Feature] Support Flashinfer fp8 blockwise GEMM kernel on Blackwell (sgl-project#6479)
* Fix LoRA bench (sgl-project#6719)
* temp
* Fix PP for Qwen3 MoE (sgl-project#6709)
* [feat] triton kernel for get_last_loc (sgl-project#6676)
* [fix] more mem for draft_extend cuda_graph (sgl-project#6726)
* [PD] bug fix:  Update status if nixl receiver send a a dummy req. (sgl-project#6720)
* Tune memory arguments on B200 (sgl-project#6718)
* Add DeepSeek-R1-0528 function call chat template (sgl-project#6725)
* refactor(tool call): Fix BaseFormatDetector tool_index issue and refactor `parse_streaming_increment` (sgl-project#6715)
* Add draft extend CUDA graph for Triton backend (sgl-project#6705)
* refactor apply_w8a8_block_fp8_linear in fp (sgl-project#6545)
* [PD] Support completion endpoint (sgl-project#6729)
* PD Rust LB (PO2) (sgl-project#6437)
* Super tiny enable sole usage of expert distribution metrics and update doc (sgl-project#6680)
* Support picking variants of EPLB algorithms (sgl-project#6728)
* Support tuning DeepEP configs (sgl-project#6742)
* [test] add ut and bm for get_last_loc (sgl-project#6746)
* Fix mem_fraction_static for AMD CI (sgl-project#6748)
* [fix][RL] Fix DeepSeekV3ForCausalLM.post_load_weights for multiple update weight (sgl-project#6265)
* Improve EPLB logical to physical dispatch map (sgl-project#6727)
* Update DeepSeek-R1-0528 function call chat template (sgl-project#6765)
* [PD] Optimize time out logic and add env var doc for mooncake (sgl-project#6761)
* Fix aiohttp 'Chunk too big' in bench_serving (sgl-project#6737)
* Support sliding window in triton backend (sgl-project#6509)
* Fix shared experts fusion error (sgl-project#6289)
* Fix one bug in the grouped-gemm triton kernel (sgl-project#6772)
* update llama4 chat template and pythonic parser (sgl-project#6679)
* feat(tool call): Enhance Llama32Detector for improved JSON parsing in non-stream (sgl-project#6784)
* Support token-level quantization for EP MoE (sgl-project#6782)
* Temporarily lower mmlu threshold for triton sliding window backend (sgl-project#6785)
* ci: relax test_function_call_required (sgl-project#6786)
* Add intel_amx backend for Radix Attention for CPU (sgl-project#6408)
* Fix incorrect LoRA weight loading for fused gate_up_proj (sgl-project#6734)
* fix(PD-disaggregation): Can not get local ip (sgl-project#6792)
* [FIX] mmmu bench serving result display error (sgl-project#6525) (sgl-project#6791)
* Bump torch to 2.7.0 (sgl-project#6788)
* chore: bump sgl-kernel v0.1.5 (sgl-project#6794)
* Improve profiler and integrate profiler in bench_one_batch_server (sgl-project#6787)
* chore: upgrade sgl-kernel v0.1.5 (sgl-project#6795)
* [Minor] Always append newline after image token when parsing chat message (sgl-project#6797)
* Update CI tests for Llama4 models (sgl-project#6421)
* [Feat] Enable PDL automatically on Hopper architecture (sgl-project#5981)
* chore: update blackwell docker (sgl-project#6800)
* misc: cache is_hopper_arch (sgl-project#6799)
* Remove contiguous before Flashinfer groupwise fp8 gemm (sgl-project#6804)
* Correctly abort the failed grammar requests & Improve the handling of abort (sgl-project#6803)
* [EP] Add cuda kernel for moe_ep_pre_reorder (sgl-project#6699)
* Add draft extend CUDA graph for flashinfer backend  (sgl-project#6805)
* Refactor CustomOp to avoid confusing bugs (sgl-project#5382)
* Tiny log prefill time (sgl-project#6780)
* Tiny fix EPLB assertion about rebalancing period and recorder window size (sgl-project#6813)
* Add simple utility to dump tensors for debugging (sgl-project#6815)
* Fix profiles do not have consistent names (sgl-project#6811)
* Speed up rebalancing when using non-static dispatch algorithms (sgl-project#6812)
* [1/2] Add Kernel support for Cutlass based Fused FP4 MoE (sgl-project#6093)
* [Router] Fix k8s Service Discovery (sgl-project#6766)
* Add CPU optimized kernels for topk and rope fusions  (sgl-project#6456)
* fix new_page_count_next_decode (sgl-project#6671)
* Fix wrong weight reference in dynamic EPLB (sgl-project#6818)
* Minor add metrics to expert location updater (sgl-project#6816)
* [Refactor] Rename `n_share_experts_fusion` as `num_fused_shared_experts` (sgl-project#6735)
* [FEAT] Add transformers backend support  (sgl-project#5929)
* [fix] recover auto-dispatch for rmsnorm and rope (sgl-project#6745)
* fix ep_moe_reorder kernel bugs (sgl-project#6858)
* [Refactor] Multimodal data processing for VLM (sgl-project#6659)
* Decoder-only Scoring API (sgl-project#6460)
* feat: add dp-rank to KV events (sgl-project#6852)
* Set `num_fused_shared_experts` as `num_shared_experts` when shared_experts fusion is not disabled (sgl-project#6736)
* Fix one missing arg in DeepEP (sgl-project#6878)
* Support LoRA in TestOpenAIVisionServer and fix fused kv_proj loading bug. (sgl-project#6861)
* support 1 shot allreduce  in 1-node and 2-node using mscclpp (sgl-project#6277)
* Fix Qwen3MoE missing token padding optimization (sgl-project#6820)
* Tiny update error hints (sgl-project#6846)
* Support layerwise rebalancing experts (sgl-project#6851)
* Tiny allow profiler API to auto create directory (sgl-project#6865)
* Support Blackwell DeepEP docker images (sgl-project#6868)
* [EP] Add cuda kernel for moe_ep_post_reorder (sgl-project#6837)
* [theta]merge 0605
* oai: fix openAI client error with single request via batch api (sgl-project#6170)
* [PD] Fix potential perf spike caused by tracker gc and optimize doc (sgl-project#6764)
* Use deepgemm instead of triton for fused_qkv_a_proj_with_mqa (sgl-project#6890)
* [CUTLASS-FP4-MOE]  Introduce CutlassMoEParams class for easy initialization of Cutlass Grouped Gems Metadata (sgl-project#6887)
* bugfix(OAI): Fix image_data processing for jinja chat templates (sgl-project#6877)
* [CPU] enable CI for PRs, add Dockerfile and auto build task (sgl-project#6458)
* AITER backend extension and workload optimizations (sgl-project#6838)
* [theta]merge
* [theta]merge
* [Feature] Support Flashinfer fmha on Blackwell (sgl-project#6930)
* Fix a bug in abort & Improve docstrings for abort (sgl-project#6931)
* Tiny support customize DeepEP max dispatch tokens per rank (sgl-project#6934)
* Sync the changes on cuda graph runners (sgl-project#6932)
* [PD] Optimize transfer queue forward logic for dummy rank (sgl-project#6922)
* [Refactor] image data process in bench_serving (sgl-project#6879)
* [fix] logical_to_all_physical_map index 256 is out of bounds in EP parallel. (sgl-project#6767)
* Add triton fused moe kernel config for E=257 on B200 (sgl-project#6939)
* [sgl-kernel] update deepgemm (sgl-project#6942)
* chore: bump sgl-kernel v0.1.6 (sgl-project#6943)
* Minor compile fused topk (sgl-project#6944)
* [Bugfix] pipeline parallelism and Eagle Qwen2 (sgl-project#6910)
* Tiny re-introduce profile id logging (sgl-project#6912)
* Add triton version as a fused_moe_triton config search key to avoid performace decrease in different Triton version (sgl-project#5955)
* reduce torch.zeros overhead in moe align block size kernel (sgl-project#6369)
* chore: upgrade sgl-kernel v0.1.6 (sgl-project#6945)
* add fbgemm moe grouped gemm kernel benchmark (sgl-project#6924)
* [Docker] Add docker file for SGL Router (sgl-project#6915)
* Disabling mixed chunked prefill when eagle is enabled (sgl-project#6874)
* Add canary for EPLB rebalancing (sgl-project#6895)
* Refactor global_server_args_dict (sgl-project#6866)
* Fuse routed scaling factor in topk_reduce kernel (sgl-project#6220)
* Update server timeout time in AMD CI. (sgl-project#6953)
* [misc] add is_cpu() (sgl-project#6950)
* Add H20 fused MoE kernel tuning configs for DeepSeek-R1/V3 (sgl-project#6885)
* Add a CUDA kernel for fusing mapping and weighted sum for MoE. (sgl-project#6916)
* chore: bump sgl-kernel v0.1.6.post1 (sgl-project#6955)
* chore: upgrade sgl-kernel v0.1.6.post1 (sgl-project#6957)
* [DeepseekR1-FP4] Add Support for nvidia/DeepSeekR1-FP4 model (sgl-project#6853)
* Revert "Fuse routed scaling factor in topk_reduce kernel (sgl-project#6220)" (sgl-project#6968)
* [AMD] Add more tests to per-commit-amd (sgl-project#6926)
* chore: bump sgl-kernel v0.1.7 (sgl-project#6963)
* Slightly improve the sampler to skip unnecessary steps (sgl-project#6956)
* rebase h20 fused_moe config (sgl-project#6966)
* Fix CI and triton moe Configs (sgl-project#6974)
* Remove unnecessary kernels of num_token_non_padded (sgl-project#6965)
* Extend cuda graph capture bs for B200 (sgl-project#6937)
* Fuse routed scaling factor in deepseek (sgl-project#6970)
* Sync cuda graph runners (sgl-project#6976)
* Fix draft extend ut stability with flush cache (sgl-project#6979)
* Fix triton sliding window test case (sgl-project#6981)
* Fix expert distribution dumping causes OOM (sgl-project#6967)
* Minor remove one kernel for DeepSeek (sgl-project#6977)
* [perf][sgl-kernel] extend cutlass_mla_decode to support num_head < 128 (sgl-project#6929)
* Enable more unit tests for AMD CI. (sgl-project#6983)
* Use torch.compile to fuse flash attention decode metadata preparation (sgl-project#6973)
* Eliminate stream sync to speed up LoRA batch init  (sgl-project#6960)
* support qwen3 emebedding (sgl-project#6990)
* Fix torch profiler bugs for bench_offline_throughput.py (sgl-project#6557)
* chore: upgrade flashinfer v0.2.6.post1 jit (sgl-project#6958)
* cleanup tmp dir (sgl-project#7007)
* chore: update pr test xeon (sgl-project#7008)
* Fix cutlass MLA gets almost zero accuracy (sgl-project#6998)
* Update amd nightly models CI. (sgl-project#6992)
* feat: add direct routing strategy to DP worker (sgl-project#6884)
* Fallback to lower triton version for unfound fused moe configs (sgl-project#7013)
* Fix torchvision version for Blackwell (sgl-project#7015)
* Simplify prepare_extend_after_decode (sgl-project#6987)
* Migrate to assertEqual (sgl-project#6741)
* Fix torch version in blackwell dockerfile (sgl-project#7017)
* chore: update pr test xeon (sgl-project#7018)
* Update default settings for blackwell (sgl-project#7023)
* Support both approximate and exact expert distribution collection (sgl-project#6964)
* Add decode req pool (sgl-project#6980)
* [theta]merge 0610
* [theta]merge 0610
* [CI] Add CI workflow for sgl-router docker build (sgl-project#7027)
* Fix fused_moe triton configs (sgl-project#7029)
* CPU: map changes from developing branch in sgl-kernel (sgl-project#6833)
* chore: bump v0.4.7 (sgl-project#7038)
* Update README.md (sgl-project#7040)