

@ShareLer (Contributor) commented May 30, 2025

Checklist Before Starting

  • Search for similar PR(s).

What does this PR do?

Fix megatron model merger.

High-Level Design

Demonstrate the high-level design if this PR is complex.

Specific Changes

  • Fix the get-rank method to support TP-only setups.
  • Fix state_dict keys after conversion.
  • Add MLA/MoE conversion support.
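As a rough illustration of what a TP-only rank fix involves, the sketch below parses tensor- and pipeline-parallel ranks from per-rank checkpoint directory names. It assumes Megatron-LM's naming convention (mp_rank_XX when the pipeline-parallel size is 1, mp_rank_XX_YYY otherwise); parse_rank_dir and infer_parallel_sizes are hypothetical helpers, not the merger's actual method, and expert-parallel suffixes are not handled.

```python
import re
from typing import List, Tuple

# Assumed Megatron-LM per-rank directory naming:
#   "mp_rank_{tp:02d}"            when the pipeline-parallel size is 1 (TP only)
#   "mp_rank_{tp:02d}_{pp:03d}"   when the pipeline-parallel size is > 1
# Expert-parallel suffixes ("..._{ep:03d}") are intentionally not handled here.
_RANK_DIR = re.compile(r"^mp_rank_(\d+)(?:_(\d+))?$")


def parse_rank_dir(name: str) -> Tuple[int, int]:
    """Return (tp_rank, pp_rank) parsed from a rank directory name (hypothetical helper)."""
    match = _RANK_DIR.match(name)
    if match is None:
        raise ValueError(f"unrecognized rank directory: {name!r}")
    tp_rank = int(match.group(1))
    # A TP-only directory has no PP suffix, so the PP rank defaults to 0
    # instead of causing a parse failure.
    pp_rank = int(match.group(2)) if match.group(2) is not None else 0
    return tp_rank, pp_rank


def infer_parallel_sizes(dir_names: List[str]) -> Tuple[int, int]:
    """Infer (tp_size, pp_size) from all rank directory names (hypothetical helper)."""
    ranks = {parse_rank_dir(name) for name in dir_names}
    tp_size = max(tp for tp, _ in ranks) + 1
    pp_size = max(pp for _, pp in ranks) + 1
    # Every (tp, pp) combination must be present exactly once.
    assert len(ranks) == len(dir_names) == tp_size * pp_size, "missing or duplicate rank directories"
    return tp_size, pp_size
```

Defaulting the missing PP suffix to rank 0 is the key design choice: it lets the same code path handle TP-only and TP+PP checkpoints.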

API

Demonstrate how the API changes if any.

Usage Example

Provide usage example(s) for easier usage.

# Add code snippet or script demonstrating how to use this 

Test

Test with Qwen3-8B and Qwen2.5-7B.

Additional Info.

Checklist Before Submitting

  • Read the Contribute Guide.
  • Apply pre-commit checks.
  • Add [BREAKING] to the PR title if it breaks any API.
  • Update the documentation about your changes in the docs.
  • Add CI test(s) if necessary.

Signed-off-by: ShareLer <ShareLe@163.com>
@vermouth1992 requested a review from @ETOgaosion on May 30, 2025 08:49
@ETOgaosion (Collaborator) commented May 31, 2025

@ShareLer Thanks a lot for helping us to fix this, it helps a lot~

Could you briefly point out what causes the vLLM inference failure? The PR also seems to involve a lot of refactoring. Is it caused by parameters missing during transfer?

@ShareLer (Contributor Author) commented Jun 1, 2025

> @ShareLer Thanks a lot for helping us to fix this, it helps a lot~
>
> Could you briefly point out what causes the vLLM inference failure? The PR also seems to involve a lot of refactoring. Is it caused by parameters missing during transfer?

Three main reasons:

  1. The names of converted layers were not rewritten in _merge_state_dicts().
    Keys in the converted ckpt look like model.decoder.layers.xxx, but they should be model.layers.xxx.
    This is also the cause of the failure in the referenced issue.

  2. The q/k/v names in the attention layer are wrong after conversion.
    linear_qkv is converted to linear_q/linear_k/linear_v, but the names should be q_proj/k_proj/v_proj.

  3. The weight of the final output layer is processed incorrectly.
    When is_value_model=False, output_layer is a ColumnParallelLinear, but its weights from the different TP ranks were not merged in _merge_state_dicts().
    (My first submission ignored the value model, which caused the CI failure. This has now been fixed.)
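A minimal sketch of the third fix, assuming weights are represented as plain lists of rows so it runs without torch: a ColumnParallelLinear splits its output dimension across TP ranks, so the per-rank shards are concatenated along dim 0 in rank order (torch.cat(shards, dim=0) on real tensors). The helper names and the assumption that a value model's head is replicated on every TP rank are illustrative, not taken from the merger.

```python
from typing import List

# A weight matrix as a list of rows: [out_features][in_features].
Matrix = List[List[float]]


def merge_column_parallel(shards: List[Matrix]) -> Matrix:
    """Merge ColumnParallelLinear weight shards ordered by TP rank.

    ColumnParallelLinear splits the output dimension (dim 0 of the weight),
    so each rank holds out_features / tp_size rows and merging concatenates
    the rows; with torch tensors this is torch.cat(shards, dim=0).
    """
    in_features = len(shards[0][0])
    assert all(len(row) == in_features for shard in shards for row in shard), "in_features mismatch across TP shards"
    merged: Matrix = []
    for shard in shards:  # shards must already be ordered by TP rank
        merged.extend(shard)
    return merged


def merge_output_layer(shards: List[Matrix], is_value_model: bool) -> Matrix:
    """Merge the final output layer (hypothetical helper).

    Assumption: for a value model the head is replicated on every TP rank,
    so rank 0's copy is kept; otherwise the lm_head is a ColumnParallelLinear
    and all shards must be concatenated.
    """
    if is_value_model:
        return shards[0]
    return merge_column_parallel(shards)
```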

@eric-haibin-lin (Collaborator) left a comment


Is it possible to add a test that reproduces the issue?

@ShareLer (Contributor Author) commented Jun 1, 2025

> Is it possible to add a test that reproduces the issue?

You can reproduce this problem easily with the CI script (e.g., the job 'e2e_ppo_trainer_megatron-qwen3' in e2e_ppo_trainer_megatron.yml): just change the test operation in python scripts/model_merger.py test --backend megatron to merge.

The previous CI test did not catch this because the test and merge operations use different logic:
First, both obtain the merged weights through the _merge_state_dicts() method. Note, however, that the state_dicts at this point still have the three problems described in my previous reply.
Next, the merge operation saved these problematic state_dicts directly as the final ckpt, whereas the test operation used in CI corrected the problematic layer names (removing decoder and fixing the qkv names).
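Following this explanation, a regression test could run the merge path (not the test path) and inspect the saved keys. The sketch below is a hypothetical helper, find_unconverted_keys, that flags Megatron-only name fragments left in a merged state_dict. The fragment list assumes Qwen/Llama-style HF naming (model.layers.N..., q_proj/k_proj/v_proj) and would need adjusting for models whose HF names legitimately contain "decoder.".

```python
from typing import Iterable, List

# Assumed fragments that should never survive conversion to HF naming for
# Qwen/Llama-style models (see the three problems described above).
MCORE_ONLY_FRAGMENTS = (
    "decoder.",    # problem 1: model.decoder.layers.N instead of model.layers.N
    "linear_qkv",  # fused qkv name that was never split or renamed
    "linear_q.",   # problem 2: should be q_proj
    "linear_k.",   # should be k_proj
    "linear_v.",   # should be v_proj
)


def find_unconverted_keys(keys: Iterable[str]) -> List[str]:
    """Return one description per (key, fragment) pair that still looks like an mcore name."""
    problems = []
    for key in keys:
        for fragment in MCORE_ONLY_FRAGMENTS:
            if fragment in key:
                problems.append(f"{key}: contains {fragment!r}")
    return problems
```

A test would then assert that find_unconverted_keys(saved_state_dict.keys()) is empty after running merge.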

@ETOgaosion (Collaborator) commented Jun 3, 2025

> The names of converted layers were not rewritten in _merge_state_dicts().
> Keys in the converted ckpt look like model.decoder.layers.xxx, but they should be model.layers.xxx.
> This is also the cause of the failure in the referenced issue.

@ShareLer Could you add some assertions to check whether the naming prefix is valid, as model_merger requires?

Or could it be made more robust when extracting elements from the key string, e.g., fetching the layer_index from index [-3] of the split string?
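To illustrate the robustness concern, here is a sketch (not the merger's code) that locates the layer index by pattern instead of by a fixed position in key.split("."). A fixed position depends on how many segments the key has, which differs between, for example, dense MLP keys and MoE expert keys; the regex instead looks for a layers.<N>. segment. The helper name layer_index is hypothetical.

```python
import re
from typing import Optional

# Matches a "layers.<N>." segment anywhere in the key, at the start or after a dot.
_LAYER_SEGMENT = re.compile(r"(?:^|\.)layers\.(\d+)\.")


def layer_index(key: str) -> Optional[int]:
    """Return the transformer layer index in a parameter name, or None.

    Works for "decoder.layers.N...", "model.layers.N..." and "layers.N..."
    alike, and for any number of trailing segments (e.g. MoE expert keys).
    Keys outside the layer stack (embeddings, final norm, output layer)
    return None instead of a wrong value.
    """
    match = _LAYER_SEGMENT.search(key)
    return int(match.group(1)) if match else None
```

Returning None rather than raising leaves the decision to the caller, which can turn it into the assertion or warning discussed in this thread.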

@ShareLer (Contributor Author) commented Jun 3, 2025

> > The names of converted layers were not rewritten in _merge_state_dicts().
> > Keys in the converted ckpt look like model.decoder.layers.xxx, but they should be model.layers.xxx.
> > This is also the cause of the failure in the referenced issue.
>
> @ShareLer Could you add some assertions to check whether the naming prefix is valid, as model_merger requires?
>
> Or could it be made more robust when extracting elements from the key string, e.g., fetching the layer_index from index [-3] of the split string?

Sorry, I don't quite understand what you mean.
Do you mean that we should perform some checks on the keys in merged_state_dict before saving the merged ckpt to ensure its validity?

@@ -444,28 +485,28 @@ def _merge_state_dicts(self, model_state_dict_lst: list[list[dict]], tp_size: in
print("skip lm_head and reward_head loading because of tie_word_embeddings")
continue

Collaborator

Maybe here we could check the mcore key to verify that it is valid for _replace_name?

Contributor Author

Yes, we can add an assert on the result of _replace_name to ensure that it is not None: the layer-name mapping is defined in self.params_mapping, so normally the result should never be None.
Do you think this is feasible?
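A minimal sketch of the proposed assertion. It assumes a small, illustrative subset of mcore-to-HF name mappings for a Qwen/Llama-style dense model; the merger's real self.params_mapping is larger, and fused weights such as linear_qkv.weight and linear_fc1.weight are split separately, so they are deliberately absent. replace_name (a stand-in for _replace_name) returns None for an unmapped key, and the hypothetical rename_checked turns that into an assertion failure instead of silently writing a bad key.

```python
import re
from typing import Dict, Optional

# Illustrative subset of an mcore -> HF name mapping for a Qwen/Llama-style dense
# model. Fused weights such as linear_qkv.weight and linear_fc1.weight are split
# into several HF tensors elsewhere, so they are deliberately not listed here.
TOP_LEVEL_MAPPING: Dict[str, str] = {
    "embedding.word_embeddings": "model.embed_tokens",
    "decoder.final_layernorm": "model.norm",
    "output_layer": "lm_head",
}
PER_LAYER_MAPPING: Dict[str, str] = {
    "self_attention.linear_qkv.layer_norm_weight": "input_layernorm.weight",
    "self_attention.linear_proj": "self_attn.o_proj",
    "mlp.linear_fc1.layer_norm_weight": "post_attention_layernorm.weight",
    "mlp.linear_fc2": "mlp.down_proj",
}

_LAYER_KEY = re.compile(r"^decoder\.layers\.(\d+)\.(.+)$")


def _map_prefix(name: str, table: Dict[str, str]) -> Optional[str]:
    """Replace the first matching dotted prefix of name using table."""
    for src, dst in table.items():
        if name == src or name.startswith(src + "."):
            return dst + name[len(src):]
    return None


def replace_name(mcore_name: str) -> Optional[str]:
    """Map an mcore parameter name to its HF name, or return None if unmapped."""
    match = _LAYER_KEY.match(mcore_name)
    if match:
        layer, rest = match.groups()
        mapped = _map_prefix(rest, PER_LAYER_MAPPING)
        # Drop "decoder.": HF expects model.layers.N, not model.decoder.layers.N.
        return f"model.layers.{layer}.{mapped}" if mapped is not None else None
    return _map_prefix(mcore_name, TOP_LEVEL_MAPPING)


def rename_checked(mcore_name: str) -> str:
    """Like replace_name, but fail loudly instead of writing a bad key."""
    hf_name = replace_name(mcore_name)
    assert hf_name is not None, f"no HF mapping for mcore key {mcore_name!r}; extend params_mapping"
    return hf_name
```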

@ETOgaosion (Collaborator), Jun 3, 2025

Maybe I didn't make myself clear (qaq). I can help add some checks and warnings~

ShareLer and others added 3 commits June 3, 2025 16:56
@ShareLer (Contributor Author) commented Jun 5, 2025

@ETOgaosion Hi, building on your code, I have added extra checks for the embedding layer and output_layer.
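Checks of this kind could look roughly like the following sketch, which validates the embedding and output layer of a merged state_dict before saving. It works on a mapping from HF key to tensor shape, assumes the HF names model.embed_tokens.weight and lm_head.weight, and follows the tie_word_embeddings skip shown in the diff above; the function name and the shape comparison are illustrative rather than the merged code, and value models (whose head has a different shape) are out of scope.

```python
from typing import Dict, Tuple

Shape = Tuple[int, ...]


def check_embedding_and_head(shapes: Dict[str, Shape], tie_word_embeddings: bool) -> None:
    """Raise AssertionError if the embedding or output layer looks wrong (hypothetical helper)."""
    embed = shapes.get("model.embed_tokens.weight")
    assert embed is not None, "embedding layer missing from merged state_dict"

    head = shapes.get("lm_head.weight")
    if tie_word_embeddings:
        # lm_head shares the embedding matrix, so the merger skips it.
        assert head is None, "lm_head.weight present although tie_word_embeddings=True"
        return

    assert head is not None, "output_layer was not converted to lm_head.weight"
    # An unmerged ColumnParallelLinear shard would have only vocab_size / tp_size rows.
    assert head == embed, f"lm_head shape {head} != embedding shape {embed}; were TP shards merged?"
```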

@ETOgaosion (Collaborator)

@ShareLer Thanks for helping, I was about to test on my machine~

@ETOgaosion merged commit cc9bc3f into volcengine:main on Jun 9, 2025
33 of 34 checks passed
yellowbee686 pushed a commit to yellowbee686/verl that referenced this pull request Jun 10, 2025
### Checklist Before Starting

- [x] Search for similar PR(s).

### What does this PR do?

Fix megatron model merger.

### High-Level Design

> Demonstrate the high-level design if this PR is complex.

### Specific Changes

- Fix the get-rank method to support TP-only setups.
- Fix state_dict keys after conversion.
- Add MLA/MoE conversion support.

### API

> Demonstrate how the API changes if any.

### Usage Example

> Provide usage example(s) for easier usage.

```python
# Add code snippet or script demonstrating how to use this 
```

### Test

Test with Qwen3-8B and Qwen2.5-7B.

### Additional Info.

- **Issue Number**: Fixes issue volcengine#1757
- **Training**: [Note which backend this PR will affect: FSDP, Megatron,
both, or none]
- **Inference**: [Note which backend this PR will affect: vLLM, SGLang,
both, or none]

### Checklist Before Submitting

- [ ] Read the [Contribute
Guide](https://github.com/volcengine/verl?tab=readme-ov-file#contribution-guide).
- [ ] Apply [pre-commit
checks](https://github.com/volcengine/verl?tab=readme-ov-file#code-linting-and-formatting).
- [ ] Add `[BREAKING]` to the PR title if it breaks any API.
- [ ] Update the documentation about your changes in the
[docs](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add CI test(s) if necessary.

---------

Signed-off-by: ShareLer <ShareLe@163.com>
Co-authored-by: ETOgaosion <gaoziyuan19@mails.ucas.ac.cn>