Conversation

vhain commented Mar 27, 2025

Motivation

I've been testing Gemma 3 and its fine-tunes with SGLang for a couple of days, and noticed weird behavior where it starts to generate garbage tokens as soon as the context length (input + generated) exceeds somewhere around 2k.

It also seems like a few others are experiencing the same thing.

I ran the same prompt through vLLM and it did not generate garbage tokens.

Modifications

This PR implements the method get_attention_sliding_window_size in Gemma3ForConditionalGeneration and Gemma3ForCausalLM, so that attention backends can be initialized with the proper sliding window size.

The get_attention_sliding_window_size implementation was brought over from Gemma 2:

# Aligned with HF's implementation, using sliding window inclusive with the last token
# SGLang assumes exclusive
def get_attention_sliding_window_size(config):
    return config.sliding_window - 1
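As a minimal, self-contained sketch (class and config names here are illustrative, not the exact PR diff), this is roughly how a model class can expose the helper so that the runner's hasattr-based lookup described below picks it up, and why the value is reduced by one:

from dataclasses import dataclass

@dataclass
class _TextConfig:
    sliding_window: int = 1024  # illustrative value, not Gemma 3's actual config

def get_attention_sliding_window_size(config):
    # HF counts the window inclusive of the current token;
    # SGLang expects the number of previous tokens only, hence the -1
    return config.sliding_window - 1

class Gemma3ForCausalLMSketch:
    def __init__(self, config):
        self.config = config

    def get_attention_sliding_window_size(self):
        # expose the helper as a model method so the runner's hasattr() check finds it
        return get_attention_sliding_window_size(self.config)

model = Gemma3ForCausalLMSketch(_TextConfig())
print(model.get_attention_sliding_window_size())  # 1023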

Investigations

After thorough investigation (trust me, it took me the whole night, as I'm not that familiar with the codebase), including a deeper comparison of vLLM's Gemma implementations and SGLang's Gemma implementations (1, 2, and 3), I found the following.

The current Gemma3Attention initializes RadixAttention with sliding_window_size:

self.attn = RadixAttention(
    self.num_heads,
    self.head_dim,
    self.scaling,
    num_kv_heads=self.num_kv_heads,
    layer_id=layer_id,
    logit_cap=getattr(self.config, "attn_logit_softcapping", None),
    sliding_window_size=self.sliding_window,
    prefix=add_prefix("attn", prefix),
)

but it seems like RadixAttention does nothing with it other than store it as a property:

sliding_window_size: int = -1,

In fact, the actual backend gets initialized alongside ForwardBatch, and it uses model_runner.sliding_window_size:

if model_runner.sliding_window_size is not None:

and this value gets set by looking up the model's get_attention_sliding_window_size:

# Parse other args
self.sliding_window_size = (
    self.model.get_attention_sliding_window_size()
    if hasattr(self.model, "get_attention_sliding_window_size")
    else None
)
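A toy reproduction of this discovery pattern (illustrative class names, not SGLang's real ones): the backend only receives a sliding window size when the model defines get_attention_sliding_window_size, which is exactly what Gemma 3 was missing:

class _ModelWithWindow:
    def get_attention_sliding_window_size(self):
        return 1024 - 1  # exclusive window, as in the Gemma 2 helper above

class _ModelWithoutWindow:
    pass

def resolve_sliding_window_size(model):
    # mirrors the hasattr-based lookup quoted above
    return (
        model.get_attention_sliding_window_size()
        if hasattr(model, "get_attention_sliding_window_size")
        else None
    )

print(resolve_sliding_window_size(_ModelWithWindow()))     # 1023 -> backend enables sliding-window attention
print(resolve_sliding_window_size(_ModelWithoutWindow()))  # None -> full attention, the pre-fix Gemma 3 behavior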


Future Work

  • It was somewhat confusing that RadixAttention takes sliding_window_size as its parameter, but the model actually needs to define get_attention_sliding_window_size for the value to take effect.
  • The updated (fixed) Gemma 3 implementation in SGLang still seems to generate slightly different output than vLLM, although I am not sure which one is correct ([Model] Add support for Gemma 3 vllm-project/vllm#14660 points out that vLLM's implementation could have less accuracy). It would be nice to compare generation results with the official implementation.

Inference Comparison

Started server with:

$ python3 -m python.sglang.launch_server --model-path google/gemma-3-4b-it --context-length 8192

Ran inference with:

const res = await fetch(`http://localhost:${process.argv[2] || '30000'}/v1/chat/completions`, {
  method: "POST",
  headers: {
    "Content-Type": "application/json",
  },
  body: JSON.stringify({
    model: "google/gemma-3-4b-it",
    max_tokens: 4096,
    // temperature: 1,
    top_p: .95,
    // frequency_penalty: 0,
    // presence_penalty: 0,
    // repetition_penalty: 1,
    top_k: 64,
    // min_p: 0,
    messages: [
      {
        role: "system",
        content:
          "Read the introduction of an individual as provided. Please translate it into Korean.",
      },
      {
        role: "user",
        // the following content was machine-generated by a random commercial LLM
        content: `The first thing most people notice about Dr. Maya Chen is her laugh—a rich, unrestrained sound that starts somewhere deep in her chest and erupts with surprising force from her petite frame. It's a laugh that seems at odds with her carefully cultivated professional image: the immaculate lab coat, the sensible shoes, the dark hair always pulled back in a practical bun. But that juxtaposition is perfectly Maya—a woman of seeming contradictions who has spent a lifetime defying expectations and carving her own path through the world.

Born in 1982 to Chinese immigrants in a small town outside Portland, Oregon, Maya grew up straddling two worlds. At home, her parents—her father a civil engineer and her mother a former concert pianist—maintained many of their cultural traditions, speaking Mandarin, celebrating lunar new year, and instilling in Maya a deep respect for education and familial duty. Outside the home, Maya navigated the complexities of being one of the few Asian-American students in her school, developing a chameleon-like ability to adapt to different social situations while maintaining her core sense of self....[TRUNCATED]`
      },
    ],
    stream: true,
  }),
});

try {
  // streamCompletions is a local helper (not shown here) that parses the SSE stream into JSON chunks
  for await (const event of streamCompletions(res)) {
    const { role, content } = event.choices?.[0]?.delta || {};
    if (role) {
      process.stdout.write(`${role}: `);
    }
    if (content) {
      process.stdout.write(`${content}`);
    }
  }
} finally {
  process.stdout.write("\n");
}

on main branch (f60f293):

Okay, yet—

Read more than…

***

Read more »

***

Read more »

Read more »

Read

Read
...[TRUNCATED]

on this branch (7bc9bfb):

Okay, here's a Korean translation of the provided introduction to Dr. Maya Chen, aiming for a balance of accuracy and natural-sounding Korean:

**마야 천 박사의 소개**

대부분의 사람들은 먼저 그녀의 웃음소리를 알아차립니다. 그것은 가슴에서부터 우러나오는 풍부하고, 억제되지 않은 소리입니다. 그녀의 작은 체구에서 뿜어져 나오는 놀라운 힘으로 시작되는 웃음이죠. 이 웃음소리는 그녀가 정성 들여 쌓아 올린 전문적인 이미지와 어울리지 않는 것처럼 보입니다. 완벽한 실험복, 편안한 신발, 항상 실용적인 똥머리로 묶어 kept는 검은 머리 말입니다. 하지만 이러한 대비는 마야를 완벽하게 정의합니다. 그녀는 예상을 깨고 자신만의 길을 개척해 온, 겉으로는 모순되는 듯한 여성입니다.

1982년 오리건주 포틀랜드 외곽의 작은 마을에서 중국 이민자 가정에서 태어난 마야는 두 세계 사이에서 자랐습니다. 집에서는 그녀의 부모님, 아버지께서는 시공사로서, 어머니께서는 은퇴한 무대 의상가로서, 많은 문화적 전통을 유지하며, 마야에게 교육과 가족의 의무에 대한 깊은 존경심을 심어주셨습니다. 집 밖에서는 마야는 학교에서 유일한 아시아-미국 학생으로서 다양한 사회적 상황에 적응하는 데 능숙해졌습니다. 동시에 자신의 핵심적인 자아를 유지하면서 말입니다.

“저는 코드를 바꾸는 법을 알아차리기 훨씬 전에 유창하게 구사하게 되었어요.” 마야는 종종 그 특유의 웃음소리와 함께 말합니다. “저는 학교에서 수학과 과학을 excel하는 마야, 중국 커뮤니티 노인들에게 의료 문서 번역을 도와주는 마야, 피아노 레슨 시간에 책상 밑에서 비밀리에 SF 소설을 읽는 마야였습니다. 왈츠 위트만이라고 한다면 저는 다중성을 품고 있었습니다.”

이러한 다양한 정체성은 마야가 교육을 받을 때 도움이 되었습니다. 먼저 MIT에서 생화학과 비교 문학을 복수 전공했고, (또 다른 모순되는 것 같지만 그녀에게는 말이 안 됐어요) 이후 스탠포드 대학에서 분자 생물학 박사 과정을 이수했고, 마지막으로 존스  Hopkins에서 의학 학위를 취득했습니다. 그 과정에서 그녀는 궁극적으로 그녀의 삶의 기반이 될, 전통 중국 의학과 최첨단 유전체 연구의 융합이라는 특이한 전문 분야를 개발했습니다.

“제 할머니는 광주에서 전통적인 치유사로 불렸어요.” 마야는 설명합니다. “어릴 적에는 그녀가 약초 요법을 준비하는 것을 보았어요. 그녀의 굳은 손가락은 엄청난 정밀함과 확신으로 움직였죠. 그녀는 이러한 조합이 작동하는 이유에 대한 분자 메커니즘을 설명할 수 없었습니다. 단순히 그것이 효과가 있다는 것을 알고 있었어요. 수세기 동안 관찰과 실천을 통해 말입니다. 서구 과학 교육을 계속하면서 저는 이러한 고대 요법과 현대 의학 사이의 잠재적인 연결고리에 매료되었습니다.”

이러한 매력으로 인해 마야는 2018년에 전 세계의 전통 요법을 조사하는 엄격한 과학적 방법을 적용하는 선구적인 기관인 Integrative Genomic Medicine Research Center를 설립했습니다. 그녀의 지도 아래 센터는 여러 문화권의 전통 약초에서 신경 퇴행성 질환 치료에 도움이 될 수 있는 몇 가지 화합물을 식별했습니다. 이 작업은 논란의 여지가 없지 않았습니다. 전통주의자들로부터 고대의 치유 관행을 분자 구성 요소로 줄이는 것에 대해 비판을 받았고, 일부 과학계에서는 그것을 허위 과학으로 간주하여 신뢰성을 더하는 것에 대해 비판을 받았습니다.

마야는 이러한 비판 모두에 침착하게 대처합니다. “과학은 우리가 우리의 가정에 의문을 제기할 때 발전합니다.” 그녀는 말합니다. “전통적인 치유 시스템과 현대 의학 모두는 눈먼 곳이 있습니다. 제 작업은 그들 사이의 공간에 있습니다.”
...[TRUNCATED]

zhyncs merged commit 0bc0bf5 into sgl-project:main Mar 27, 2025
1 of 19 checks passed
vhain deleted the ryan/gemma3/attention-sliding-window-conf branch March 27, 2025 19:00

atbe commented Mar 27, 2025

thank you for fixing this!

@merrymercy

@vhain This seems to break the CI https://github.com/sgl-project/sglang/actions/runs/14119078538/job/39555810283#step:4:221. Can you take a look?

vhain commented Mar 28, 2025

@merrymercy Seems like the error is a CUDA OOM? This PR did not directly introduce any code that allocates CUDA memory, apart from configuring the sliding window for the attention backend. Maybe the attention backend (in this case flashinfer) requires slightly more memory when configured with a bigger sliding window?

Perhaps we just need to decrease --mem-fraction-static in

other_args=[
    "--trust-remote-code",
    "--chat-template",
    "gemma-it",
],

vhain commented Mar 28, 2025

@merrymercy let me pull up a PR with lower --mem-fraction-static for Gemma 3 vision test and see if it passes the CI.
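For illustration only, a hedged sketch of what that tweak could look like, assuming the flag is simply appended to the existing other_args (0.75 matches the value zhyncs suggests below; this is not the actual follow-up diff):

# hedged sketch, not the actual follow-up PR: lower --mem-fraction-static for the Gemma 3 vision test
other_args = [
    "--trust-remote-code",
    "--chat-template",
    "gemma-it",
    "--mem-fraction-static",
    "0.75",
]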

zhyncs commented Mar 28, 2025

@merrymercy let me pull up a PR with lower --mem-fraction-static for Gemma 3 vision test and see if it passes the CI.

0.75 is ok @vhain @merrymercy

jimoosciuc pushed a commit to Furion-cn/sglang that referenced this pull request Apr 17, 2025
pi314ever pushed a commit to pi314ever/sglang that referenced this pull request Apr 23, 2025
* Fix ut mla-test-1-gpu-amd (sgl-project#4813)

Co-authored-by: Zhang Kaihong <zhangkaihong.zkh@alibaba-inc.com>

* Remove Unintended Capture Batch Sizes in AMD HIP Graph Runner (sgl-project#4638)

* [k8s] Clarified the usage of shared memory. (sgl-project#4341)

* gemma3: impl `get_attention_sliding_window_size` for attn init (sgl-project#4823)

* add partial_json_parser and einops (sgl-project#4827)

* fix the release doc dependency issue (sgl-project#4828)

* Update doc for DeepSeek-V3-0324 (sgl-project#4825)

* deps: lazy import optional dependencies `gguf` and `torchvision` (sgl-project#4826)

* Update MMMU Benchmark instructions (sgl-project#4694)

* Fix the nightly eval by lowering the threshold of `neuralmagic/gemma-2-2b-it-FP8` (sgl-project#4830)

* Basic Cleanup (sgl-project#4833)

* Support (1 <= dp < tp) in the dp attention in DeepEP (sgl-project#4770)

Co-authored-by: Cheng Wan <cwan39@gatech.edu>

* [Fix] Add compressed_tensors as deps (sgl-project#4819)

* Fix error due to CustomAllreduce setup failure (sgl-project#4815)

Signed-off-by: Kebe <mail@kebe7jun.com>

* use default for torch.ops (sgl-project#4835)

* [CI] Remove unused imports with Ruff to pre-commit config, only to benchmarks/docs/examples folder (sgl-project#3969)

* [Misc] Fix issues reported by torchfix (sgl-project#4837)

* Include context length in /v1/models response. (sgl-project#4809)

* [Fix] `self.worker` assignment in `TpModelWorker` and refactor references (sgl-project#4788)

Signed-off-by: Xinyuan Tong <justinning0323@outlook.com>

* Fix the lora adapter when lora path is none (sgl-project#4799)

Co-authored-by: Beichen Ma <mabeichen12@gmail.com>

* fix: fix typo of comments in w8a8_fp8.py (sgl-project#4843)

* Remove retry in nightly tests (sgl-project#4846)

* Fix CI of test_patch_torch (sgl-project#4844)

* IPv6 support (sgl-project#3949)

Signed-off-by: Brayden Zhong <b8zhong@uwaterloo.ca>

* ci: add condition for daily docker build (sgl-project#4487)

* [Fix] fix output_top_logprobs is not exist (sgl-project#4597)

* fix: when use SGLANG_PORT this env,port is str (sgl-project#4528)

Signed-off-by: rongfu.leng <lenronfu@gmail.com>

* Support Page Size > 1 for FA3 (sgl-project#4832)

Co-authored-by: Qingquan Song <ustcsqq@gmail.com>
Co-authored-by: Baizhou Zhang <sobereddiezhang@gmail.com>

* Fix Engine error when enabling DP attention (sgl-project#4648)

* fix: Inappropriate lack of Optional type on OpenAI ChatCompletionRequest (sgl-project#4681)

* Support controlling nsys start and end range programmatically (sgl-project#4688)

* Remove empty tool function name (sgl-project#4704)

Signed-off-by: Kebe <mail@kebe7jun.com>

* Fix missing arguments in SchedulePolicy and RadixCache initialization in tests. (sgl-project#4712)

* get the python version from env (sgl-project#4729)

* Fix torch.cuda.MemPool() internal assertion failure (sgl-project#4687)

Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>

* Super tiny remove unused code (sgl-project#4750)

* Support with_stack and record_shapes in profiler (sgl-project#4740)

Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>

* test: reduce `mem_fraction_static` for gemma3 vision test (sgl-project#4840)

* Fix CI tests (sgl-project#4853)

* Fix fa3 cuda graph page_size > 1 precision and page_size=1 speed (sgl-project#4855)

* Revert "get the python version from env (sgl-project#4729)" (sgl-project#4863)

* [Feature] add multi-rank support for Lora (sgl-project#4492)

Co-authored-by: rudy152 <czh1137892874@gmail.com>

* Clean up `import vllm` in quantization/__init__.py (sgl-project#4834)

* Fix wrong variable name when stopping memory profile (sgl-project#4772)

* [Feat] support deepgemm for cmake (sgl-project#4864)

* Make torch compile configurable for biased_grouped_topk (sgl-project#4749)

* update sgl-kernel test ci (sgl-project#4866)

* fix sampling issue (sgl-project#4871)

* bump sgl-kernel 0.0.5.post4 (sgl-project#4768)

* fix sgl-kernel cu118 build (sgl-project#4872)

* [Feature] Support FA3 backend for MLA (sgl-project#4831)

* upgrade sgl-kernel 0.0.5.post4 (sgl-project#4873)

* update torch compile doc (sgl-project#4874)

* bump v0.4.4.post3 (sgl-project#4878)

* Fix BadRequestError wrong arguments and remove openai dependency (sgl-project#4882)

* Improve stack trace of retry errors (sgl-project#4845)

* Tiny fix doc error (sgl-project#4795)

* [Docs] Update DeepGEMM at README.md (sgl-project#4886)

* Update CODEOWNERS (sgl-project#4889)

* Delete test_deep_gemm.py (sgl-project#4891)

* Add deepseek style fused moe group gate selection kernel (sgl-project#4530)

* quick fix: add default for new kernel (sgl-project#4898)

* remove setup for sgl-kernel (sgl-project#4899)

* [Misc] Clean m.def and add Development Tips (sgl-project#4890)

* fix allreduce test (sgl-project#4909)

* Support page size > 1 + eagle (sgl-project#4908)

* Fix retract for page size > 1 (sgl-project#4914)

* [Feature] use pytest for sgl-kernel (sgl-project#4896)

* fix bmm fp8 (sgl-project#4926)

* Fix the timeout for unit-test-2-gpu in pr-test.yml (sgl-project#4927)

* Fix 2-gpu CI test and suppress some warnings (sgl-project#4930)

* [feat] add fa3 in sgl-kernel (sgl-project#4902)

Co-authored-by: Sleepcoo <Sleepcoo@gmail.com>

* Fix sglang frontend's incorrect dependency on torch (sgl-project#4931)

* [Fix] avoid stream sync and torch compile in prefill for fa3 backend (sgl-project#4932)

* cleanup sgl-kernel (sgl-project#4933)

* [Fix] Improve Lora tests and reduce CI runtime (sgl-project#4925)

* Fix DeepSeek bug causing 2.2% MMLU drop when TP!=DP (sgl-project#4883)

Co-authored-by: ch-wan <cwan39@gatech.edu>

* [Fix] Add torch compile for torch.clamp back (sgl-project#4936)

* Fix oom error for large page size (sgl-project#4913)

Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>

* [feat] interface for platforms abstraction (sgl-project#4928)

* [Fix] revert clean m.def for cudagraph (sgl-project#4944)

* refactor: multimodal data (sgl-project#4754)

* bump sgl-kernel v0.0.6 (sgl-project#4950)

* [Build] Fix cuda12.8 build error in nvfp4_scaled_mm_kernels.cu (sgl-project#4953)

* use fa3 in sgl-kernel (sgl-project#4954)

* Revert PR 4764 & 4813 related to R1 RoPE (sgl-project#4959)

* [Feature] Support DeepEP Low Latency (sgl-project#4767)

Co-authored-by: sleepcoo <sleepcoo@gmail.com>
Co-authored-by: laixinn <xielx@shanghaitech.edu.cn>
Co-authored-by: ch-wan <cwan39@gatech.edu>

* update bench_serving (sgl-project#4958)

* Prevent memory leak of retract_decode when page_size > 1 (sgl-project#4977)

* [VLM RLHF] Take Image input for verl vlm rollout (sgl-project#4915)

Signed-off-by: Xinyuan Tong <justinning0323@outlook.com>
Co-authored-by: GeLee <leege233@gmail.com>

* Large page size aligned hierarchical caching (sgl-project#4581)

* bug fix for hicache host eviction (sgl-project#4989)

* sgl scaled_fp8_quant support output padding (sgl-project#4861)

* Add Eagle Speculative Decoding to FA3 Backend (sgl-project#4951)

Co-authored-by: hebiao064 <hebiaobuaa@gmail.com>
Co-authored-by: Baizhou Zhang <sobereddiezhang@gmail.com>
Co-authored-by: zcnrex <zcnrex@gmail.com>

* Update tokenizer_manager.py (sgl-project#5008)

* [sgl-kernel] per token group quant support COLUMN MAJOR (sgl-project#4817)

* update cutlass tag (sgl-project#5011)

* Feature/revise docs ci (sgl-project#5009)

* fix: fix illegal cuda memory access at fused_moe_kernel (sgl-project#4727)

Co-authored-by: yuethe <yuethe@tencent.com>

* [Build] Support build sgl-kernel with ccache (sgl-project#5020)

* fix deepgemm as well (sgl-project#5030)

* try to fix ci oserror (sgl-project#5024)

* Replace enable_flashinfer_mla argument with attention_backend (sgl-project#5005)

* Small refactor DeepEPMode to clean up code a bit (sgl-project#4992)

* [Fix] fix fa3 build at cu118 (sgl-project#5036)

* Revert "Replace enable_flashinfer_mla argument with attention_backend" (sgl-project#5048)

* bump sgl-kernel v0.0.7 (sgl-project#5046)

* update eagle-3 docs (sgl-project#4796)

Co-authored-by: Yifan Zhang <zhangyif21@mails.tsinghua.edu.cn>

* Add LlavaLlamaForCausaLM in MultiModal Processors (sgl-project#5039)

Co-authored-by: Ravi Theja Desetty <ravitheja@Ravis-MacBook-Pro.local>

* Update the retry count (sgl-project#5051)

* upgrade sgl-kernel v0.0.7 (sgl-project#5049)

* [2/3] fix dsv3 awq issue  (sgl-project#4625)

Co-authored-by: 晟海 <huangtingwei.htw@antgroup.com>
Co-authored-by: laixinn <xielx@shanghaitech.edu.cn>

* Feature/revise docs ci (sgl-project#5056)

* Add H20 fused MoE kernel tuning configs for DeepSeek V3/R1 (sgl-project#5057)

* [fix] remove `cuda_device_count_stateless` (sgl-project#5060)

* Small refactor DeepEPDispatcher into subclasses (sgl-project#4994)

* Support async DeepEP by splitting into two stages (sgl-project#4995)

* Cleanup unused resources after DeepEP operation (sgl-project#4996)

* Add DeepSeek V3/R1 shared experts fusion (sgl-project#4918)

* [deepep] fix: shared experts are not initialized when shared experts fusion is enabled (sgl-project#5072)

* fix dummy-load deepseekv2 (sgl-project#4535)

* support sgl-kernel on blackwell (sgl-project#5074)

* FA3 Spec Decoding to support top k = 1 and add cuda graph support (sgl-project#5050)

Co-authored-by: Qingquan Song <ustcsqq@gmail.com>
Co-authored-by: Chunan Zeng <zcnrex@gmail.com>

* [Revision] Replace enable_flashinfer_mla argument with attention_backend (sgl-project#5052)

* upgrade transformers 4.51.0 (sgl-project#5088)

* sgl-kernel transfer custom allreduce from trt kernel to vllm kernel (sgl-project#5079)

* bump sgl-kernel 0.0.8 (sgl-project#5089)

* python transfer custom allreduce from trt kernel to vllm kernel (sgl-project#5080)

* bump v0.4.4.post4 (sgl-project#5091)

* Fix: Reduce the number of document ci attempts to avoid long ci running (sgl-project#5097)

Co-authored-by: shuaills <shishuaiuoe@gmail.com>

* Add Llama4 support (sgl-project#5092)

Co-authored-by: Cheng Wan <cwan39@gatech.edu>
Co-authored-by: fzyzcjy <ch271828n@outlook.com>
Co-authored-by: ispobock <ispobaoke@163.com>

* Fix refactor error - fp8.py (sgl-project#5106)

Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>

* bump v0.4.5 (sgl-project#5117)

* Workaround for async copy issue in HPU eager mode (sgl-project#1)

Signed-off-by: Rahul Vijayaraghavan <rvijayaraghavan@habana.ai>
Co-authored-by: Rahul Vijayaraghavan <rvijayaraghavan@habana.ai>

* [SW-223847]: Fix sgl_kernel module not available (sgl-project#2)

Co-authored-by: vikram singh shekhawat <vshekhawat@habana.ai>

* [Base] Enable torch compile (sgl-project#4)

* [SW-226331] disable dynamic shape in torch compile mode

Signed-off-by: Mohit Sinha <msinha@habana.ai>

---------

Signed-off-by: Kebe <mail@kebe7jun.com>
Signed-off-by: Xinyuan Tong <justinning0323@outlook.com>
Signed-off-by: Brayden Zhong <b8zhong@uwaterloo.ca>
Signed-off-by: rongfu.leng <lenronfu@gmail.com>
Signed-off-by: Rahul Vijayaraghavan <rvijayaraghavan@habana.ai>
Signed-off-by: Mohit Sinha <msinha@habana.ai>
Co-authored-by: strgrb <zhangkaihong.zkh@antgroup.com>
Co-authored-by: Zhang Kaihong <zhangkaihong.zkh@alibaba-inc.com>
Co-authored-by: AinL <gmlwns5176@gmail.com>
Co-authored-by: Jiří Suchomel <jiri.suchomel@statsperform.com>
Co-authored-by: Juwan Yoo <ryan@tmfi.us>
Co-authored-by: Yineng Zhang <me@zhyncs.com>
Co-authored-by: Ke Bao <ISPObaoke@163.com>
Co-authored-by: Ravi Theja <ravi03071991@gmail.com>
Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>
Co-authored-by: Daniel Holanda <holand.daniel@gmail.com>
Co-authored-by: tarinkk <129432511+tarinkk@users.noreply.github.com>
Co-authored-by: Cheng Wan <cwan39@gatech.edu>
Co-authored-by: Junrong Lin <33685709+ocss884@users.noreply.github.com>
Co-authored-by: Kebe <mail@kebe7jun.com>
Co-authored-by: Brayden Zhong <b8zhong@uwaterloo.ca>
Co-authored-by: Jon Durbin <jon@jondurbin.com>
Co-authored-by: XinyuanTong <115166877+JustinTong0323@users.noreply.github.com>
Co-authored-by: Qiaolin Yu <qy254@cornell.edu>
Co-authored-by: Beichen Ma <mabeichen12@gmail.com>
Co-authored-by: Jiaqi <57028284+ZhuJiaqi9905@users.noreply.github.com>
Co-authored-by: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com>
Co-authored-by: Vincent <vincentzhongy+githubvincent4@gmail.com>
Co-authored-by: warjiang <1096409085@qq.com>
Co-authored-by: lambert0312 <lambert80.ios@gmail.com>
Co-authored-by: rongfu.leng <lenronfu@gmail.com>
Co-authored-by: Stefan He <hebiaobuaa@gmail.com>
Co-authored-by: Qingquan Song <ustcsqq@gmail.com>
Co-authored-by: Baizhou Zhang <sobereddiezhang@gmail.com>
Co-authored-by: BroadbentJim <BroadbentJim@users.noreply.github.com>
Co-authored-by: vikram singh shekhawat <vshekhawat@habana.ai>
Co-authored-by: DavidChan <chengwei0519@163.com>
Co-authored-by: chaobo jia <91889375+jcbjcbjc@users.noreply.github.com>
Co-authored-by: rudy152 <czh1137892874@gmail.com>
Co-authored-by: Fr4nk1in <sh.fu@outlook.com>
Co-authored-by: yinfan98 <1106310035@qq.com>
Co-authored-by: Yi Zhang <1109276519@qq.com>
Co-authored-by: Adarsh Shirawalmath <114558126+adarshxs@users.noreply.github.com>
Co-authored-by: Sleepcoo <Sleepcoo@gmail.com>
Co-authored-by: SEPLOS <seplos@aliyun.com>
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
Co-authored-by: JieXin Liang <Alcanderian@users.noreply.github.com>
Co-authored-by: Mick <mickjagger19@icloud.com>
Co-authored-by: Yuhong Guo <yuhong.gyh@antgroup.com>
Co-authored-by: Jinyan Chen <93358689+liz-badada@users.noreply.github.com>
Co-authored-by: laixinn <xielx@shanghaitech.edu.cn>
Co-authored-by: GeLee <leege233@gmail.com>
Co-authored-by: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com>
Co-authored-by: zcnrex <zcnrex@gmail.com>
Co-authored-by: Kaiyu Yang <yangky@umich.edu>
Co-authored-by: renxin <90580890+renxinx@users.noreply.github.com>
Co-authored-by: saltyfish66 <38240284+saltyfish66@users.noreply.github.com>
Co-authored-by: yuethe <yuethe@tencent.com>
Co-authored-by: simveit <69345428+simveit@users.noreply.github.com>
Co-authored-by: Yifan Zhang <zhangyif21@mails.tsinghua.edu.cn>
Co-authored-by: Ravi Theja Desetty <ravitheja@Ravis-MacBook-Pro.local>
Co-authored-by: AniZpZ <zhuangsen.zp@antgroup.com>
Co-authored-by: 晟海 <huangtingwei.htw@antgroup.com>
Co-authored-by: Tommy Yang <tommyyang0524@gmail.com>
Co-authored-by: Cheng Wan <54331508+ch-wan@users.noreply.github.com>
Co-authored-by: inkcherry <mingzhi.liu@intel.com>
Co-authored-by: mlmz <54172054+minleminzui@users.noreply.github.com>
Co-authored-by: shuaills <shishuaiuoe@gmail.com>
Co-authored-by: Chang Su <chang.s.su@oracle.com>
Co-authored-by: fzyzcjy <ch271828n@outlook.com>
Co-authored-by: HAI <hixiao@gmail.com>
Co-authored-by: Rahul Vijayaraghavan <rahul.vijayaraghavan@intel.com>
Co-authored-by: Rahul Vijayaraghavan <rvijayaraghavan@habana.ai>
Co-authored-by: Jay Thakur <jthakur@habana.ai>
Co-authored-by: Anshuman Tripathy <atripathy@habana.ai>

gabinguo commented May 7, 2025

related issue #6099
