
Conversation

aiyiwang2025
Contributor

@aiyiwang2025 aiyiwang2025 commented Jun 26, 2025

Description

Currently, the Hunyuan inference team supports the Hunyuan-A13B model. By adding the hunyuan_v1.py related files, this PR adds support for the two architectures HunYuanDenseV1ForCausalLM and HunYuanMoEV1ForCausalLM.

We have validated the accuracy of this PR. HunYuan (a new MoE LLM from Tencent) will be open-sourced in the coming days.

Thanks~

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Summary of Changes

Hello @aiyiwang2025, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces comprehensive support for Tencent's HunYuan V1 models, encompassing both dense and Mixture-of-Experts (MoE) variants, into the vLLM framework. This integration enables efficient inference for these new models by incorporating their specific architectural components, including specialized rotary embeddings and advanced attention mechanisms, and by registering them within the vLLM ecosystem.

Highlights

  • New Model Definition: A new module hunyuan_v1.py has been added, defining the core components for Tencent's HunYuan models. This includes their specific MLP, attention mechanisms (supporting self and cross-attention), decoder layers, and the overall model structure, enabling vLLM to support these new architectures.
  • Mixture-of-Experts (MoE) Support: The pull request introduces the HunYuanSparseMoeBlock to specifically handle the Mixture-of-Experts (MoE) architecture found in HunYuanMoEV1. This includes logic for expert routing and shared MLP components, crucial for efficient inference of MoE models.
  • Dynamic NTK Alpha Rotary Embedding: A new rotary embedding class, DynamicNTKAlphaRotaryEmbedding, has been implemented and integrated into the get_rope function. This addition supports HunYuan's unique dynamic NTK alpha scaling for positional embeddings, ensuring accurate model behavior.
  • Cross-Layer Attention (CLA) Handling: The HunYuan model architecture now includes specific logic to manage and utilize prev_kv_states for Cross-Layer Attention (CLA), which is a feature potentially used in HunYuan models to enhance performance or context understanding.
  • Model Registration: Both HunYuanDenseV1ForCausalLM and HunYuanMoEV1ForCausalLM models are now registered within the vLLM model registry, making them discoverable and loadable by the framework for inference (see the usage sketch below).
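To make the registration concrete, here is a minimal offline-inference sketch using vLLM's public LLM API. The checkpoint name is taken from the documentation table added in this PR; the prompt and sampling settings are arbitrary placeholders and are not part of the PR itself.

```python
# Minimal usage sketch (not part of this PR): loading the newly registered
# HunYuanMoEV1ForCausalLM checkpoint through vLLM's offline LLM API.
from vllm import LLM, SamplingParams

llm = LLM(model="tencent/Hunyuan-A13B-Instruct", trust_remote_code=True)
sampling_params = SamplingParams(temperature=0.7, max_tokens=128)  # arbitrary settings

outputs = llm.generate(["Briefly introduce the HunYuan model family."], sampling_params)
print(outputs[0].outputs[0].text)
```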
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request by creating an issue comment (i.e. a comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

| Feature | Command | Description |
|---|---|---|
| Code Review | `/gemini review` | Performs a code review for the current pull request in its current state. |
| Pull Request Summary | `/gemini summary` | Provides a summary of the current pull request in its current state. |
| Comment | `@gemini-code-assist` | Responds in comments when explicitly tagged, both in issue comments and review comments. |
| Help | `/gemini help` | Displays a list of available commands. |

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

The pull request introduces support for Tencent's HunYuanMoEV1 model within the vLLM framework. It includes the addition of a new rotary embedding and a model definition file, along with modifications to the model registry. The changes appear to be well-structured and validated, but there are a few opportunities to improve code clarity and remove redundant code.

def _compute_cos_sin_cache(self) -> torch.Tensor:
    # For Hunyuan DynamicNTKAlphaRotaryEmbedding
    max_len = self.max_position_embeddings
    base = self.base * self.scaling_alpha ** (self.rotary_dim / (self.rotary_dim-2))
Contributor

medium

Consider adding parentheses to improve readability and explicitly define the order of operations in this calculation. It's a good practice to make the intent clear, especially when dealing with multiple operators¹.

Style Guide References

Suggested change
base = self.base * self.scaling_alpha ** (self.rotary_dim / (self.rotary_dim-2))
base = self.base * (self.scaling_alpha ** (self.rotary_dim / (self.rotary_dim-2)))

Footnotes

  1. While there is no explicit style guide provided, using parentheses for clarity in complex calculations is a common practice to improve code readability.
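For context, a self-contained sketch of the dynamic-NTK-alpha idea being reviewed here: the rotary base is rescaled by alpha^(d/(d-2)) before the cos/sin cache is built. The helper below is a plain-PyTorch illustration under simplified assumptions, not the PR's DynamicNTKAlphaRotaryEmbedding class.

```python
# Plain-PyTorch illustration of the dynamic-NTK-alpha cache computation under
# review (simplified names; not the PR's DynamicNTKAlphaRotaryEmbedding class).
import torch

def compute_cos_sin_cache(max_position_embeddings: int, rotary_dim: int,
                          base: float = 10000.0,
                          scaling_alpha: float = 1.0) -> torch.Tensor:
    # Rescale the rope base by alpha ** (d / (d - 2)), as in the snippet above.
    base = base * (scaling_alpha ** (rotary_dim / (rotary_dim - 2)))
    inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
    t = torch.arange(max_position_embeddings, dtype=torch.float)
    freqs = torch.einsum("i,j->ij", t, inv_freq)
    # Cache cos and sin concatenated on the last dim, mirroring vLLM's rotary base class.
    return torch.cat((freqs.cos(), freqs.sin()), dim=-1)

cache = compute_cos_sin_cache(max_position_embeddings=4096, rotary_dim=128, scaling_alpha=1000.0)
print(cache.shape)  # torch.Size([4096, 128])
```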

Comment on lines 343 to 344
k_tmp = torch.empty_like(k) # Todo: reduant rotary embedding
q, _ = self.rotary_emb(positions, q, k_tmp)
Contributor

medium

This rotary embedding is marked as redundant. It should be removed to avoid unnecessary computation and memory allocation. If it's truly needed in some cases, clarify the conditions under which it's used; otherwise, remove it entirely. Leaving dead code can lead to confusion and maintenance issues¹.

            k_tmp = torch.empty_like(k)  # Todo: reduant rotary embedding
            q, _ = self.rotary_emb(positions, q, k_tmp)

Style Guide References

Footnotes

  1. While there is no explicit style guide provided, avoiding dead code is a common practice to improve code maintainability.
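For readers following this thread: in the cross-layer-attention ("CLA") path the key/value tensors are reused from an earlier layer, where they were already rotated, so only the query still needs the rotary embedding here; the empty k placeholder only exists to satisfy the rotary_emb signature. A simplified standalone sketch with illustrative names (not the PR's exact code):

```python
# Illustrative helper (not the PR's exact code): in a cross-layer-attention
# ("cross") layer, K/V are reused from an earlier layer's kv_states, which were
# already rotated there, so only q goes through the rotary embedding; the empty
# placeholder tensor merely satisfies rotary_emb's (positions, q, k) signature.
from typing import Callable, Optional

import torch

def apply_rope_with_cla(
    rotary_emb: Callable,
    positions: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    prev_kv_states: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
):
    if prev_kv_states is not None:            # CLA branch: reuse already-rotated K/V
        k, v = prev_kv_states
        k_placeholder = torch.empty_like(k)   # rotated copy is discarded
        q, _ = rotary_emb(positions, q, k_placeholder)
    else:                                     # ordinary self-attention branch
        q, k = rotary_emb(positions, q, k)
    return q, k, v
```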

@DarkLight1337
Member

DarkLight1337 commented Jun 26, 2025

Thanks for implementing your model in vLLM! Can you add it to the list of supported models and update the tests as well?


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run the fastcheck CI, which runs a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@aiyiwang2025 aiyiwang2025 requested a review from hmellor as a code owner June 26, 2025 07:33
@mergify mergify bot added the documentation Improvements or additions to documentation label Jun 26, 2025
@DarkLight1337 DarkLight1337 requested a review from Isotr0py June 26, 2025 07:40
Member

@Isotr0py Isotr0py left a comment

Just added some initial comments at a glance; I will take a deeper look later. PTAL!

Comment on lines 576 to 583
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
}
Member

Suggested change
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
}

This field isn't used anymore.

Contributor Author

Done, removed the unused parts.

Comment on lines 215 to 216
class HunYuanAttention(nn.Module):

Member

I suggest decoupling the self-attn and cross-attn implementations into HunYuanAttention and HunYuanCrossAttention, respectively.

Contributor Author

Currently, HunYuanMoEV1ForCausalLM uses self-attn rather than cross-attn, so we keep HunYuanAttention.

Member

Although cross-attn is unused, I think decoupling this can improve readability; otherwise, this attention layer implementation is a bit too long to read.

Contributor Author

Done.

Decoupled the self-attn and cross-attn implementations into HunYuanAttention and HunYuanCrossAttention, respectively.

@@ -388,6 +388,7 @@ Specified using `--task generate`.
| `MiniMaxM1ForCausalLM` | MiniMax-Text | `MiniMaxAI/MiniMax-M1-40k`, `MiniMaxAI/MiniMax-M1-80k`etc. | | | |
| `MiniMaxText01ForCausalLM` | MiniMax-Text | `MiniMaxAI/MiniMax-Text-01`, etc. | | | |
| `Zamba2ForCausalLM` | Zamba2 | `Zyphra/Zamba2-7B-instruct`, `Zyphra/Zamba2-2.7B-instruct`, `Zyphra/Zamba2-1.2B-instruct`, etc. | | | |
| `HunYuanMoEV1ForCausalLM` | Hunyuan-80B-A13B | `tencent/Hunyuan-A13B-Instruct`, `tencent/Hunyuan-A13B-Pretrain`, `tencent/Hunyuan-A13B-Instruct-FP8`etc. | | | ✅︎ |
Member

I think we haven't supported cross attention in v1 yet; does this model work with v1?

Contributor Author

Because it is self-attn, it currently supports v1 and has been verified.

Member

Can you also update the dense model in the documentation? And it seems that PP should be supported too?

@@ -0,0 +1,851 @@
# coding=utf-8
Collaborator

Suggested change
# coding=utf-8
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# coding=utf-8

Contributor Author

Done.

"embed_tokens": "input_embeddings",
"lm_head": "output_embeddings",
}
embedding_padding_modules = ["lm_head"]
Collaborator

@jeejeelee jeejeelee Jun 26, 2025

There is no need to explicitly set embedding_modules and embedding_padding_modules.

Contributor Author

Done, removed the unused parts.

@aiyiwang2025 aiyiwang2025 force-pushed the hunyuan_a13b branch 2 times, most recently from 864f1bc to 06c60c3 on June 26, 2025 10:05
Comment on lines 126 to 127
"HunYuanDenseV1ForCausalLM": ("hunyuan_v1", "HunYuanV1ForCausalLM"),
"HunYuanMoEV1ForCausalLM": ("hunyuan_v1", "HunYuanV1ForCausalLM"),
Member

I think we had better decouple the dense model and MoE model implementations, because the MoE model has more complicated weight-loading logic, and coupling them makes maintenance difficult.

Contributor Author

Done.

Decoupled the dense model and MoE model implementations.
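For illustration, the decoupled registry entries might look roughly like the following; the module names are inferred from the hunyuan_v1_dense.py / hunyuan_v1_moe.py file names that appear later in this thread, so treat them as assumptions rather than the final merged code.

```python
# Sketch of post-decoupling registry entries (module names inferred from the
# hunyuan_v1_dense.py / hunyuan_v1_moe.py files mentioned later in this PR;
# the dict name here is illustrative, not vLLM's actual registry variable).
_HUNYUAN_REGISTRY_ENTRIES = {
    "HunYuanDenseV1ForCausalLM": ("hunyuan_v1_dense", "HunYuanDenseV1ForCausalLM"),
    "HunYuanMoEV1ForCausalLM": ("hunyuan_v1_moe", "HunYuanMoEV1ForCausalLM"),
}
```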


Co-authored-by: quinnrong <quinnrong@tencent.com>
Signed-off-by: aiyiwang <aiyiwang@tencent.com>
@xjpang

xjpang commented Jun 28, 2025

An exception is thrown when loading the Hunyuan-A13B-Instruct-FP8 model. @aiyiwang2025

(VllmWorker rank=2 pid=31732) ERROR 06-28 22:10:15 [multiproc_executor.py:487]     self.mlp = HunYuanSparseMoeBlock(
(VllmWorker rank=2 pid=31732) ERROR 06-28 22:10:15 [multiproc_executor.py:487]   File "/data/miniconda3/envs/env-3.10/lib/python3.10/site-packages/vllm/model_executor/models/hunyuan_v1_moe.py", line 154, in __init__
(VllmWorker rank=2 pid=31732) ERROR 06-28 22:10:15 [multiproc_executor.py:487]     self.experts = FusedMoE(
(VllmWorker rank=2 pid=31732) ERROR 06-28 22:10:15 [multiproc_executor.py:487]   File "/data/miniconda3/envs/env-3.10/lib/python3.10/site-packages/vllm/model_executor/layers/fused_moe/layer.py", line 866, in __init__
(VllmWorker rank=2 pid=31732) ERROR 06-28 22:10:15 [multiproc_executor.py:487]     raise ValueError("Duplicate layer name: {}".format(prefix))
(VllmWorker rank=2 pid=31732) ERROR 06-28 22:10:15 [multiproc_executor.py:487] ValueError: Duplicate layer name:

@intervitens
Contributor

intervitens commented Jun 28, 2025

I also had the error that @xjpang got. Patched it by adding prefix to experts and gate modules

git diff
diff --git a/vllm/model_executor/models/hunyuan_v1_moe.py b/vllm/model_executor/models/hunyuan_v1_moe.py
index 1262434a8..54177acdf 100644
--- a/vllm/model_executor/models/hunyuan_v1_moe.py
+++ b/vllm/model_executor/models/hunyuan_v1_moe.py
@@ -124,6 +124,7 @@ class HunYuanSparseMoeBlock(nn.Module):
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
         layer_id: int = -1,
+        prefix: str = "",
     ):
         super().__init__()
         self.tp_size = get_tensor_model_parallel_world_size()
@@ -159,10 +160,15 @@ class HunYuanSparseMoeBlock(nn.Module):
             reduce_results=False,
             renormalize=True if top_k > 1 else False,
             quant_config=quant_config,
+            prefix=f"{prefix}.experts",
         )
 
         self.gate = ReplicatedLinear(
-            config.hidden_size, config.num_experts, bias=False, quant_config=None
+            config.hidden_size,
+            config.num_experts,
+            bias=False,
+            quant_config=None,
+            prefix=f"{prefix}.gate",
         )
         if config.use_mixed_mlp_moe > 0:
             # Get layer_id num_shared_expert if config.num_shared_expert is a list
@@ -517,6 +523,7 @@ class HunYuanDecoderLayer(nn.Module):
             config=config,
             quant_config=quant_config,
             layer_id=layer_id,
+            prefix=f"{prefix}.mlp",
         )
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(
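A brief note on why the patch above works: the "Duplicate layer name" error in the log comes from FusedMoE rejecting a prefix it has already seen, i.e. every HunYuanSparseMoeBlock was handing FusedMoE the same default prefix. Threading a unique prefix such as model.layers.<i>.mlp.experts through the constructor makes each registration distinct. A toy illustration of the collision (not vLLM code):

```python
# Toy illustration (not vLLM code): layers keyed by their prefix string must be
# unique, so a shared default prefix fails as soon as a second layer registers.
registered: dict[str, object] = {}

def register_layer(prefix: str, layer: object) -> None:
    if prefix in registered:
        raise ValueError(f"Duplicate layer name: {prefix}")
    registered[prefix] = layer

register_layer("model.layers.0.mlp.experts", object())
register_layer("model.layers.1.mlp.experts", object())  # ok: prefixes are unique
# Registering with an empty prefix twice would raise "Duplicate layer name: "
```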

@xjpang

xjpang commented Jun 29, 2025

I also had the error that @xjpang got. Patched it by adding prefix to experts and gate modules

git diff

Got it. Thanks.

Signed-off-by: aiyiwang <aiyiwang@tencent.com>
@aiyiwang2025
Contributor Author

I also had the error that @xjpang got. Patched it by adding prefix to experts and gate modules
git diff

Got it. Thanks.

@xjpang In fact, I did not encounter the above problem. You can try the solution from @intervitens. I have also synchronized the corresponding code into this PR.

@xjpang

xjpang commented Jun 29, 2025

@aiyiwang2025 Can you provide a reasoning parser and a function calling parser?

@aiyiwang2025
Contributor Author

aiyiwang2025 commented Jun 29, 2025

@aiyiwang2025 Can you provide a reasoning parser and a function calling parser?

@xjpang The reasoning parser is currently under development.

For the tool call parser, you can refer to this.

return final_hidden_states.view(orig_shape)


class HunYuanAttention(nn.Module):
Member

I think we can reuse the attention and MLP implementations from hunyuan_v1_dense.py.

Contributor Author

Done. Removed the duplicate logic in hunyuan_v1_moe.py.

Comment on lines 391 to 399
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
)
Member

Suggested change
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
)
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
attn_type=AttentionType.ENCODER_DECODER,
)

We need to pass AttentionType.ENCODER_DECODER for cross attn.

Contributor Author

Done.

Comment on lines 460 to 461
attention_type = ("cross" if layer_id >= 0
and layer_id % cla_factor != 0 else "self")
Member

Suggested change
attention_type = ("cross" if layer_id >= 0
and layer_id % cla_factor != 0 else "self")
attention_type = (AttentionType.ENCODER_DECODER if layer_id >= 0
and layer_id % cla_factor != 0 else AttentionType.DECODER)

We can use vLLM's AttentionType enum here:

class AttentionType:
    """
    Attention type.
    Use string to be compatible with `torch.compile`.
    """
    # Decoder attention between previous layer Q/K/V
    DECODER = "decoder"
    # Encoder attention between previous layer Q/K/V for encoder-decoder
    ENCODER = "encoder"
    # Encoder attention between previous layer Q/K/V
    ENCODER_ONLY = "encoder_only"
    # Attention between dec. Q and enc. K/V for encoder-decoder
    ENCODER_DECODER = "encoder_decoder"

Contributor Author

Done.

Comment on lines 709 to 710
if "mlp.experts" in name:
continue
Member

Suggested change
if "mlp.experts" in name:
continue

Seems unnecessary for the dense model.

Contributor Author

Done.


@@ -259,6 +259,7 @@ def check_available_online(
"Zamba2ForCausalLM": _HfExamplesInfo("Zyphra/Zamba2-7B-instruct"),
"MiMoForCausalLM": _HfExamplesInfo("XiaomiMiMo/MiMo-7B-RL",
trust_remote_code=True),
"HunYuanMoEV1ForCausalLM": _HfExamplesInfo("tencent/Hunyuan-A13B-Instruct"),
Member

We need to register the dense model here as well.

Contributor Author

Done. Added HunYuanDenseV1ForCausalLM.

Note:
We are currently doing some HF model governance. The architecture of the previously released dense model is called HunYuanForCausalLM; subsequent dense models will be called HunYuanDenseV1ForCausalLM. If you want to run the previous model, you need to change its architecture name (see the sketch below). This PR does not include adaptation for the previous model.
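For anyone who wants to run one of those earlier checkpoints anyway, a hedged sketch of the architecture rename mentioned above: editing the checkpoint's config.json "architectures" field is a standard Hugging Face convention, but the exact steps (and the path used here) are placeholders, not part of this PR.

```python
# Hypothetical helper (not part of this PR): point an older HunYuan checkpoint,
# whose config.json declares "HunYuanForCausalLM", at the new architecture name
# before loading it. Edit a local copy of the checkpoint; the path is a placeholder.
import json
from pathlib import Path

cfg_path = Path("/path/to/local/hunyuan-checkpoint/config.json")  # placeholder path
cfg = json.loads(cfg_path.read_text())
cfg["architectures"] = ["HunYuanDenseV1ForCausalLM"]
cfg_path.write_text(json.dumps(cfg, indent=2, ensure_ascii=False))
```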

Comment on lines 171 to 173
is_neox_style = True
if quant_config is not None and quant_config.get_name() == "gguf":
is_neox_style = False
Member

I think this is only for Llama GGUF models.

Contributor Author

Done. Removed the related code.

Comment on lines 286 to 288
is_neox_style = True
if quant_config is not None and quant_config.get_name() == "gguf":
is_neox_style = False
Member

Ditto.

Contributor Author

Done. Removed the related code.

Comment on lines 666 to 676
split_params_mapping = [
(".gate_up_proj", ".gate_and_up_proj", 2, [(1, 1), (0, 1)], None),
(
".qkv_proj",
".qkv_proj",
num_attention_heads + num_kv_heads * 2,
[("q", num_attention_heads), ("k", num_kv_heads),
("v", num_kv_heads)],
self._split_qkv_weight,
),
]
Member

Why do we have to split weights that have already been stacked? I think we should be able to load them directly.

Contributor Author

This is due to historical reasons: the layout of the originally stacked weights does not match the required layout, so they have to be split and then re-concatenated, which is why the code here needs to be retained (see the toy illustration below).
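To illustrate the idea under toy assumptions (the shapes and names below are placeholders, not the PR's actual helper): the checkpoint's fused tensor is cut into q/k/v pieces according to the head counts from split_params_mapping, and each piece is then handed to the stacked qkv_proj loader under its shard id.

```python
# Toy illustration (placeholder shapes/names, not the PR's helper): cut the
# checkpoint's stacked qkv tensor into q/k/v pieces according to the head
# counts, then hand each piece to the stacked qkv_proj loader under its shard id.
import torch

num_attention_heads, num_kv_heads, head_dim, hidden_size = 4, 2, 8, 32
shard_sizes = [
    num_attention_heads * head_dim,  # q
    num_kv_heads * head_dim,         # k
    num_kv_heads * head_dim,         # v
]
fused_qkv = torch.randn(sum(shard_sizes), hidden_size)  # checkpoint's stacked weight

for shard_id, piece in zip(("q", "k", "v"), torch.split(fused_qkv, shard_sizes, dim=0)):
    # In the real model each piece would be passed to qkv_proj's weight_loader
    # with this shard_id, which writes it at the right offset of the parameter.
    print(shard_id, tuple(piece.shape))
```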

Signed-off-by: aiyiwang <aiyiwang@tencent.com>
Signed-off-by: aiyiwang <aiyiwang@tencent.com>
Signed-off-by: aiyiwang <aiyiwang@tencent.com>

mergify bot commented Jun 30, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @aiyiwang2025.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jun 30, 2025
Signed-off-by: aiyiwang <aiyiwang@tencent.com>
@mergify mergify bot removed the needs-rebase label Jun 30, 2025
aiyiwang2025 and others added 2 commits June 30, 2025 14:17
Signed-off-by: aiyiwang <aiyiwang@tencent.com>
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
@mergify mergify bot added the ci/build label Jun 30, 2025
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Collaborator

@jeejeelee jeejeelee left a comment

Thank you for your contribution. Considering that I have confirmed with the PR author that the generation results align, let's merge this PR first. Related improvements can be completed in subsequent PRs.

@jeejeelee jeejeelee added the ready ONLY add when PR is ready to merge/full CI is needed label Jun 30, 2025
@jeejeelee jeejeelee enabled auto-merge (squash) June 30, 2025 13:24
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
@celsowm

celsowm commented Jun 30, 2025

Hi! Will HunYuanConfig be included too?

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
@jeejeelee jeejeelee removed the ready ONLY add when PR is ready to merge/full CI is needed label Jul 1, 2025
jeejeelee added 2 commits July 1, 2025 08:50
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
@jeejeelee jeejeelee added the ready ONLY add when PR is ready to merge/full CI is needed label Jul 1, 2025
@vllm-bot vllm-bot merged commit ecad851 into vllm-project:main Jul 1, 2025
71 of 75 checks passed
koiker pushed a commit to koiker/vllm that referenced this pull request Jul 1, 2025
Signed-off-by: aiyiwang <aiyiwang@tencent.com>
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Co-authored-by: quinnrong <quinnrong@tencent.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
Signed-off-by: Rafael Marcelino Koike <rafael.koike@oracle.com>
@huachenheli
Contributor

huachenheli commented Jul 2, 2025

Looks like the OpenAI-Compatible Tool Use test is failing due to this PR's change to rope scaling:
cc @jeejeelee @Isotr0py

[2025-07-01T21:39:01Z] ERROR 07-01 14:39:01 [core.py:519] EngineCore failed to start.
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core.py", line 510, in run_engine_core
    engine_core = EngineCoreProc(*args, **kwargs)
  File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core.py", line 394, in __init__
    super().__init__(vllm_config, executor_class, log_stats,
  File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core.py", line 75, in __init__
    self.model_executor = executor_class(vllm_config)
  File "/usr/local/lib/python3.12/dist-packages/vllm/executor/executor_base.py", line 53, in __init__
    self._init_executor()
  File "/usr/local/lib/python3.12/dist-packages/vllm/executor/uniproc_executor.py", line 48, in _init_executor
    self.collective_rpc("load_model")
  File "/usr/local/lib/python3.12/dist-packages/vllm/executor/uniproc_executor.py", line 57, in collective_rpc
    answer = run_method(self.driver_worker, method, args, kwargs)
  File "/usr/local/lib/python3.12/dist-packages/vllm/utils.py", line 2716, in run_method
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.12/dist-packages/vllm/v1/worker/gpu_worker.py", line 185, in load_model
    self.model_runner.load_model()
  File "/usr/local/lib/python3.12/dist-packages/vllm/v1/worker/gpu_model_runner.py", line 1793, in load_model
    self.model = model_loader.load_model(
  File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/model_loader/base_loader.py", line 38, in load_model
    model = initialize_model(vllm_config=vllm_config,
  File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/model_loader/utils.py", line 65, in initialize_model
    return model_class(vllm_config=vllm_config, prefix=prefix)
  File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/internlm2.py", line 331, in __init__
    self.model = model_type(vllm_config=vllm_config,
  File "/usr/local/lib/python3.12/dist-packages/vllm/compilation/decorators.py", line 152, in __init__
    old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs)
  File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/internlm2.py", line 270, in __init__
    self.start_layer, self.end_layer, self.layers = make_layers(
  File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/utils.py", line 640, in make_layers
    maybe_offload_to_cpu(layer_fn(prefix=f"{prefix}.{idx}"))
  File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/internlm2.py", line 272, in <lambda>
    lambda prefix: layer_type(
  File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/internlm2.py", line 203, in __init__
    self.attention = InternLM2Attention(
  File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/internlm2.py", line 134, in __init__
    self.rotary_emb = get_rope(
  File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/rotary_embedding.py", line 1967, in get_rope
    scaling_alpha = rope_scaling["alpha"]
KeyError: 'alpha'
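The failure mode is that the new branch in get_rope reads rope_scaling["alpha"] unconditionally, so any model whose rope_scaling dict does not carry that key (here InternLM2's dynamic NTK config) now crashes at load time. A hedged sketch of a defensive guard, not necessarily what #20343 actually does; the function name and variant labels below are illustrative only:

```python
# Hedged sketch (not necessarily the actual fix in #20343): decide which
# dynamic-NTK variant get_rope should build from the rope_scaling dict instead
# of reading rope_scaling["alpha"] unconditionally. Names here are illustrative.
def pick_dynamic_ntk_variant(rope_scaling: dict) -> tuple[str, float]:
    if "alpha" in rope_scaling:               # HunYuan-style alpha scaling
        return "dynamic_ntk_alpha", float(rope_scaling["alpha"])
    return "dynamic_ntk_frequency", float(rope_scaling.get("factor", 1.0))

print(pick_dynamic_ntk_variant({"rope_type": "dynamic", "alpha": 1000.0}))
print(pick_dynamic_ntk_variant({"rope_type": "dynamic", "factor": 2.0}))  # factor-based config
```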

@DarkLight1337
Member

Should be fixed by #20343; sorry for merging this.

CSWYF3634076 pushed a commit to CSWYF3634076/vllm that referenced this pull request Jul 2, 2025
Signed-off-by: aiyiwang <aiyiwang@tencent.com>
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Co-authored-by: quinnrong <quinnrong@tencent.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
avigny pushed a commit to avigny/vllm that referenced this pull request Jul 31, 2025
Signed-off-by: aiyiwang <aiyiwang@tencent.com>
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Co-authored-by: quinnrong <quinnrong@tencent.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
Signed-off-by: avigny <47987522+avigny@users.noreply.github.com>
jinzhen-lin pushed a commit to jinzhen-lin/vllm that referenced this pull request Aug 9, 2025
Signed-off-by: aiyiwang <aiyiwang@tencent.com>
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Co-authored-by: quinnrong <quinnrong@tencent.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
googlercolin pushed a commit to googlercolin/vllm that referenced this pull request Aug 29, 2025
Signed-off-by: aiyiwang <aiyiwang@tencent.com>
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Co-authored-by: quinnrong <quinnrong@tencent.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>