
Conversation

@ISEEKYAN
Contributor

Achieves 74.3 on GSM8K, while Moonlight is reported at 77.4.

Still WIP on the performance diff.

@ISEEKYAN ISEEKYAN changed the title Mcore moonlight [mcore] moonlight (small model with deepseekv3 arch) Apr 28, 2025
@BearBiscuit05
Collaborator

Thank you for the great work. Could you provide a dpsk-v3 script in the example for testing?

@ISEEKYAN
Contributor Author

> Thank you for the great work. Could you provide a dpsk-v3 script in the example for testing?

Just added a script.

@jinqinn
Contributor

jinqinn commented May 22, 2025

@ISEEKYAN I ran dpsk-v3 with this patch, but it failed. Does this patch support training for dpsk-v3?

@ISEEKYAN
Contributor Author

> @ISEEKYAN I ran dpsk-v3 with this patch, but it failed. Does this patch support training for dpsk-v3?

What is the error info? It is supposed to support dpsk-v3 except for the MTP part for now. I will update the PR to support the full version.

@jinqinn
Contributor

jinqinn commented May 22, 2025

> @ISEEKYAN I ran dpsk-v3 with this patch, but it failed. Does this patch support training for dpsk-v3?

> What is the error info? It is supposed to support dpsk-v3 except for the MTP part for now. I will update the PR to support the full version.

Thank you for your reply.
I disabled pack_seqs because of an _apply_rotary_pos_emb_bshd error:

```python
output = forward_fn(model, input_ids, attention_mask, position_ids, sequence_parallel=self.tf_config.sequence_parallel, pack_seqs=False)
```

Python Dependencies and Versions:
vllm 0.8.2
megatron-lm 0.12.0

Error Log:

File "/verl/verl/models/mcore/model_forward.py", line 41, in gptmodel_forward
    output = model(input_ids=new_input_ids, attention_mask=new_attention_mask, position_ids=new_position_ids)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  
  ...

  File "/Megatron-LM/megatron/core/models/gpt/gpt_model.py", line 334, in forward
    hidden_states = self.decoder(
                    ^^^^^^^^^^^^^
  File "/Megatron-LM/megatron/core/transformer/transformer_block.py", line 529, in forward
    hidden_states, context = layer(
                             ^^^^^^
  File "/Megatron-LM/megatron/core/transformer/transformer_layer.py", line 786, in __call__
    return super(MegatronModule, self).__call__(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Megatron-LM/megatron/core/transformer/transformer_layer.py", line 389, in forward
    pre_mlp_layernorm_output, residual, context = self._forward_attention(*args, **kwargs)
                                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Megatron-LM/megatron/core/transformer/transformer_layer.py", line 449, in _forward_attention
    attention_output_with_bias = self.self_attention(
                                 ^^^^^^^^^^^^^^^^^^^^
  File "/Megatron-LM/megatron/core/transformer/multi_latent_attention.py", line 202, in forward
    core_attn_out = self.core_attention(
                    ^^^^^^^^^^^^^^^^^^^^
    ...
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Megatron-LM/megatron/core/extensions/transformer_engine.py", line 809, in forward
    core_attn_out = super().forward(
                    ^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py", line 1010, in forward
    raise ValueError(
ValueError: No dot product attention backend is available for the provided inputs. Please run with NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=2 to find out the reasons for disabling all backends.
```

@ISEEKYAN could you please take a look at this error and help me fix it?

@ISEEKYAN
Contributor Author

> @ISEEKYAN I ran dpsk-v3 with this patch, but it failed. Does this patch support training for dpsk-v3?

> What is the error info? It is supposed to support dpsk-v3 except for the MTP part for now. I will update the PR to support the full version.

> Thank you for your reply. I disabled pack_seqs because of an _apply_rotary_pos_emb_bshd error: [...] @ISEEKYAN could you please take a look at this error and help me fix it?

The mcore version was a pre-release of 0.12 when I initially made this PR, and there was a known bug in multi_latent_attention.py where sequence packing did not work with MLA. At that time I patched the MLA code as shown below, but I found that the v0.12 MLA code has since changed. I will update this PR for mcore v0.12. Feel free to try with the patch.

[screenshot of the MLA patch]
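The actual diff is only in the screenshot; as a hypothetical sketch of the mechanism (the method name comes from the patch header quoted later in this PR), a verl-side monkey-patch would look roughly like:

```python
# Hypothetical sketch of the patching mechanism only; the real fix (see the
# screenshot above) rewrote the packed-sequence handling inside the method.
from megatron.core.transformer.multi_latent_attention import MLASelfAttention

_orig_get_qkv = MLASelfAttention.get_query_key_value_tensors

def _patched_get_qkv(self, *args, **kwargs):
    # Corrected rotary-embedding handling for packed_seq_params would go
    # here; this stub just delegates to the original implementation.
    return _orig_get_qkv(self, *args, **kwargs)

# Apply before any GPTModel with MLA layers is instantiated.
MLASelfAttention.get_query_key_value_tensors = _patched_get_qkv
```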

@ISEEKYAN
Contributor Author

> @ISEEKYAN Thanks for your support! I enabled pack_seqs, and it works fine on mcore 0.12 after applying the patch. However, I still encounter another error:

File "/verl/verl/models/mcore/model_forward.py", line 41, in gptmodel_forward
    output = model(input_ids=new_input_ids, attention_mask=new_attention_mask, position_ids=new_position_ids)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  
  ...

  (TaskRunner pid=236671)   File "/verl/verl/single_controller/ray/base.py", line 625, in func
(TaskRunner pid=236671)     return getattr(self.worker_dict[key], name)(*args, **kwargs)
(TaskRunner pid=236671)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(TaskRunner pid=236671)   File "/verl/verl/single_controller/base/decorator.py", line 534, in inner
(TaskRunner pid=236671)     return func(*args, **kwargs)
(TaskRunner pid=236671)            ^^^^^^^^^^^^^^^^^^^^^
(TaskRunner pid=236671)   File "/verl/verl/utils/debug/performance.py", line 78, in f
(TaskRunner pid=236671)     return self.log(decorated_function, *args, **kwargs)
(TaskRunner pid=236671)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(TaskRunner pid=236671)   File "/verl/verl/utils/debug/performance.py", line 88, in log
(TaskRunner pid=236671)     output = func(*args, **kwargs)
(TaskRunner pid=236671)              ^^^^^^^^^^^^^^^^^^^^^
(TaskRunner pid=236671)   File "/verl/verl/workers/megatron_workers.py", line 467, in compute_log_prob
(TaskRunner pid=236671)     old_log_probs, entropys = self.actor.compute_log_prob(data=output, calculate_entropy=True)
(TaskRunner pid=236671)                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(TaskRunner pid=236671)   File "/verl/verl/utils/debug/performance.py", line 78, in f
(TaskRunner pid=236671)     return self.log(decorated_function, *args, **kwargs)
(TaskRunner pid=236671)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(TaskRunner pid=236671)   File "/verl/verl/utils/debug/performance.py", line 88, in log
(TaskRunner pid=236671)     output = func(*args, **kwargs)
(TaskRunner pid=236671)              ^^^^^^^^^^^^^^^^^^^^^
(TaskRunner pid=236671)   File "/verl/verl/workers/actor/megatron_actor.py", line 189, in compute_log_prob
(TaskRunner pid=236671)     output = self.forward_backward_batch(data, forward_only=True, post_process_fn=compute_logprobs_fn, calculate_entropy=calculate_entropy)
(TaskRunner pid=236671)              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(TaskRunner pid=236671)   File "/verl/verl/workers/actor/megatron_actor.py", line 389, in forward_backward_batch
(TaskRunner pid=236671)     losses_reduced = forward_backward_func(
(TaskRunner pid=236671)                      ^^^^^^^^^^^^^^^^^^^^^^
(TaskRunner pid=236671)   File "/Megatron-LM/megatron/core/pipeline_parallel/schedules.py", line 1847, in forward_backward_pipelining_without_interleaving
(TaskRunner pid=236671)     output_tensor, num_tokens = forward_step(
(TaskRunner pid=236671)                                 ^^^^^^^^^^^^^
(TaskRunner pid=236671)   File "/Megatron-LM/megatron/core/pipeline_parallel/schedules.py", line 277, in forward_step
(TaskRunner pid=236671)     output_tensor, loss_func = forward_step_func(data_iterator, model)
(TaskRunner pid=236671)                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(TaskRunner pid=236671)   File "/verl/verl/workers/actor/megatron_actor.py", line 371, in forward_step
(TaskRunner pid=236671)     output = forward_fn(model, input_ids, attention_mask, position_ids, sequence_parallel=self.tf_config.sequence_parallel)
(TaskRunner pid=236671)              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(TaskRunner pid=236671)   File "/verl/verl/models/mcore/model_forward.py", line 30, in gptmodel_forward
(TaskRunner pid=236671)     output_orig = model(
(TaskRunner pid=236671)                   ^^^^^^
(TaskRunner pid=236671)   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
(TaskRunner pid=236671)     return self._call_impl(*args, **kwargs)
(TaskRunner pid=236671)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(TaskRunner pid=236671)   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
(TaskRunner pid=236671)     return forward_call(*args, **kwargs)
(TaskRunner pid=236671)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(TaskRunner pid=236671)   File "/Megatron-LM/megatron/core/distributed/data_parallel_base.py", line 22, in forward
(TaskRunner pid=236671)     return self.module(*inputs, **kwargs)
(TaskRunner pid=236671)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(TaskRunner pid=236671)   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
(TaskRunner pid=236671)     return self._call_impl(*args, **kwargs)
(TaskRunner pid=236671)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(TaskRunner pid=236671)   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
(TaskRunner pid=236671)     return forward_call(*args, **kwargs)
(TaskRunner pid=236671)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(TaskRunner pid=236671)   File "/Megatron-LM/megatron/core/transformer/module.py", line 178, in forward
(TaskRunner pid=236671)     outputs = self.module(*inputs, **kwargs)
(TaskRunner pid=236671)               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(TaskRunner pid=236671)   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
(TaskRunner pid=236671)     return self._call_impl(*args, **kwargs)
(TaskRunner pid=236671)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(TaskRunner pid=236671)   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
(TaskRunner pid=236671)     return forward_call(*args, **kwargs)
(TaskRunner pid=236671)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(TaskRunner pid=236671)   File "/Megatron-LM/megatron/core/models/gpt/gpt_model.py", line 334, in forward
(TaskRunner pid=236671)     hidden_states = self.decoder(
(TaskRunner pid=236671)                     ^^^^^^^^^^^^^
(TaskRunner pid=236671)   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
(TaskRunner pid=236671)     return self._call_impl(*args, **kwargs)
(TaskRunner pid=236671)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(TaskRunner pid=236671)   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
(TaskRunner pid=236671)     return forward_call(*args, **kwargs)
(TaskRunner pid=236671)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(TaskRunner pid=236671)   File "/Megatron-LM/megatron/core/transformer/transformer_block.py", line 515, in forward
(TaskRunner pid=236671)     hidden_states, context = layer(
(TaskRunner pid=236671)                              ^^^^^^
(TaskRunner pid=236671)   File "/Megatron-LM/megatron/core/transformer/transformer_layer.py", line 740, in __call__
(TaskRunner pid=236671)     return super(MegatronModule, self).__call__(*args, **kwargs)
(TaskRunner pid=236671)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(TaskRunner pid=236671)   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
(TaskRunner pid=236671)     return self._call_impl(*args, **kwargs)
(TaskRunner pid=236671)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(TaskRunner pid=236671)   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
(TaskRunner pid=236671)     return forward_call(*args, **kwargs)
(TaskRunner pid=236671)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(TaskRunner pid=236671)   File "/Megatron-LM/megatron/core/transformer/transformer_layer.py", line 374, in forward
(TaskRunner pid=236671)     pre_mlp_layernorm_output, residual, context = self._forward_attention(*args, **kwargs)
(TaskRunner pid=236671)                                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(TaskRunner pid=236671)   File "/Megatron-LM/megatron/core/transformer/transformer_layer.py", line 428, in _forward_attention
(TaskRunner pid=236671)     attention_output_with_bias = self.self_attention(
(TaskRunner pid=236671)                                  ^^^^^^^^^^^^^^^^^^^^
(TaskRunner pid=236671)   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
(TaskRunner pid=236671)     return self._call_impl(*args, **kwargs)
(TaskRunner pid=236671)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(TaskRunner pid=236671)   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
(TaskRunner pid=236671)     return forward_call(*args, **kwargs)
(TaskRunner pid=236671)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(TaskRunner pid=236671)   File "/Megatron-LM/megatron/core/transformer/multi_latent_attention.py", line 193, in forward
(TaskRunner pid=236671)     core_attn_out = self.core_attention(
(TaskRunner pid=236671)                     ^^^^^^^^^^^^^^^^^^^^
(TaskRunner pid=236671)   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
(TaskRunner pid=236671)     return self._call_impl(*args, **kwargs)
(TaskRunner pid=236671)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(TaskRunner pid=236671)   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
(TaskRunner pid=236671)     return forward_call(*args, **kwargs)
(TaskRunner pid=236671)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(TaskRunner pid=236671)   File "/Megatron-LM/megatron/core/extensions/transformer_engine.py", line 809, in forward
(TaskRunner pid=236671)     core_attn_out = super().forward(
(TaskRunner pid=236671)                     ^^^^^^^^^^^^^^^^
(TaskRunner pid=236671)   File "/usr/local/lib/python3.12/dist-packages/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py", line 1010, in forward
(TaskRunner pid=236671)     raise ValueError(
(TaskRunner pid=236671) ValueError: No dot product attention backend is available for the provided inputs. Please run with NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=2 to find out the reasons for disabling all backends.
```

It might be because of a TE version mismatch. I used docker://whatcanyousee/verl:nvpytorch-cu128-vllm0.8.3-sglang0.4.5-mcore0.12.0-te2.2-dev last month; please try with a te2.2 or te2.3 image.
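If it still reproduces, the error message's own suggestion can be followed to see why each backend was rejected; a minimal sketch (set the variables before Transformer Engine initializes attention):

```python
# Minimal sketch: turn on Transformer Engine's backend-selection debug
# output, as the ValueError suggests, and confirm the installed TE version.
import os
os.environ["NVTE_DEBUG"] = "1"        # enable TE debug logging
os.environ["NVTE_DEBUG_LEVEL"] = "2"  # verbose: prints why backends are disabled

import transformer_engine
print(transformer_engine.__version__)  # the images above ship TE 2.2 / 2.3
```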

@ISEEKYAN
Contributor Author

ISEEKYAN commented May 23, 2025

I updated the PR on top of the latest version of verl to use the EP and avoid_pad_logits optimizations, which reduce GPU memory consumption a lot. Once merged with #1638, it will be possible to train Moonlight-16B-A3B with 1 node/8 GPUs.

To train dpsk-v3 671B, a few TODOs remain:

@jinqinn @duomicoding your participation would be very helpful.

@ISEEKYAN ISEEKYAN marked this pull request as ready for review May 23, 2025 12:32
@ISEEKYAN
Contributor Author

Update: my last run with 2 nodes achieves 87 on GSM8K, and training is noticeably more stable than in my experiments last month; training Moonlight is now ready.
See the wandb log at url.

@duomicoding

> @ISEEKYAN hello, I can run dpsk-v3 successfully using this PR, but the reward score is abnormal. Could you please show me some training log information, such as the loss curve or a normal reward curve? I need that information to debug my code.

> My training was not perfect either. As the steps grow, the ppo_kl increases to an unacceptable number. I guess there are some mistakes in generating the transformer_config; we could find it out together.

> [training-curve screenshot]

@ISEEKYAN Thank you for your great work! I can reproduce your result: the ppo_kl increases like yours, and the reward declines gradually during the late stages of training. May I ask if you have any new discoveries about this question?

@ISEEKYAN
Contributor Author

> @ISEEKYAN hello, I can run dpsk-v3 successfully using this PR, but the reward score is abnormal. [...]

> My training was not perfect either. As the steps grow, the ppo_kl increases to an unacceptable number. [...]

> @ISEEKYAN Thank you for your great work! I can reproduce your result [...] May I ask if you have any new discoveries about this question?

Please check the latest curve at url; the ppo_kl curve is now much better than in the April version.

@ISEEKYAN ISEEKYAN mentioned this pull request May 26, 2025
@jinqinn
Contributor

jinqinn commented May 26, 2025

OOM error:

```
actor_id=9af19a7927c3087b60d0ca2c01000000, repr=<main_ppo.TaskRunner object at 0x7f9e7ed66c80>)
  File "/workspace/verl/verl/trainer/main_ppo.py", line 192, in run
    trainer.init_workers()
  File "/workspace/verl/verl/trainer/ppo/ray_trainer.py", line 740, in init_workers
    self.ref_policy_wg.init_model()
  File "/workspace/verl/verl/single_controller/ray/base.py", line 49, in func
    output = ray.get(output)
ray.exceptions.ActorDiedError: The actor died unexpectedly before finishing this task.
        class_name: create_colocated_worker_cls.<locals>.WorkerDict
        actor_id: 2827586afada19fe1e55a00901000000
        pid: 1361797
        name: lK2Oa7WorkerDict_0:1
        namespace: 4788b225-7f99-4f3c-91f7-c3ce0d71986d
        ip: 172.21.17.94
The actor is dead because its worker process has died. Worker exit type: SYSTEM_ERROR Worker exit detail: Worker unexpectedly exits with a connection error code 2. End of file. There are some potential root causes. (1) The process is killed by SIGKILL by OOM killer due to high memory usage. (2) ray stop --force is called. (3) The worker is crashed unexpectedly due to SIGSEGV or other unexpected errors.
```

@duomicoding

> Please check the latest curve at url; the ppo_kl curve is now much better than in the April version.

@ISEEKYAN Thank you for your reply. Regarding the increasing ppo_kl curve and the declining reward curve, may I ask if you have any new discoveries or found any bugs?

@ISEEKYAN
Contributor Author

> @ISEEKYAN Thank you for your reply. Regarding the increasing ppo_kl curve and the declining reward curve, may I ask if you have any new discoveries or found any bugs?

With the latest updates, the ppo_kl increases very slowly and there is no reward decline any more.
Check val-core/openai/gsm8k/reward/mean@1 for the reward curve and actor/ppo_kl for ppo_kl. Note that there are three bad curves, which are from the April experiments. I thought it might be fixed by setting moe_router_dtype="fp64" and disable_bf16_reduced_precision_matmul=True in config_converter.py, but I am not sure.
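A minimal sketch of where those overrides would land, assuming config_converter.py ultimately builds a Megatron-LM TransformerConfig (the two field names come from the comment above; whether they fix the curves is unconfirmed):

```python
# Hypothetical sketch, not a confirmed fix: pass the two overrides when
# config_converter.py constructs the Megatron-LM TransformerConfig.
from megatron.core.transformer import TransformerConfig

def build_transformer_config(**kwargs_from_hf_config):
    return TransformerConfig(
        moe_router_dtype="fp64",                     # higher-precision MoE router math
        disable_bf16_reduced_precision_matmul=True,  # avoid reduced-precision BF16 matmuls
        **kwargs_from_hf_config,                     # num_layers, hidden_size, ...
    )
```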

@ETOgaosion ETOgaosion mentioned this pull request May 27, 2025
```python
dtype=dtype,
use_cpu_initialization=False,
add_bias_linear=False,
attention_backend=AttnBackend.fused,
```
@ccclyu
Collaborator

ccclyu commented May 28, 2025

Is AttnBackend.fused specific to the deepseek-v3 model? Is AttnBackend.auto enough here?

@ISEEKYAN
Contributor Author

When fed AttnBackend.auto, TE would choose flash, but flash attention is not implemented for MLA; the error info is:
ValueError: No dot product attention backend is available for the provided inputs. Please run with NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=2 to find out the reasons for disabling all backends.
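A minimal sketch of that pinning, using the field shown in the diff hunk above:

```python
# Sketch: pin TE's fused attention backend for MLA models instead of letting
# AttnBackend.auto resolve to flash, which has no MLA implementation and
# triggers the "No dot product attention backend" ValueError.
from megatron.core.transformer.enums import AttnBackend

transformer_config_overrides = dict(attention_backend=AttnBackend.fused)
```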

```
data.max_response_length=512 \
data.filter_overlong_prompts=True \
data.truncation='error' \
+data.trust_remote_code=True \
```
Collaborator

Should trust_remote_code be set in the model section per ppo_megatron_trainer.yaml?

Collaborator

Oh, data preprocessing might need this as well. Please ignore this if I misunderstand.

@ISEEKYAN
Contributor Author

Is this another topic beyond supporting moonlight? Would it be better to commit another small PR for the config file modification?

Collaborator

Agreed. We should track any change in the config and keep it consistent.

```python
# limitations under the License.

# there is some bug in mcore 0.12, so we need to patch it
# 1. `get_query_key_value_tensors` in `multi_latent_attention.py` works wrong when packed_seq_params is not None
```
Collaborator

Just curious whether there is an issue open in the Megatron-LM repo so that we can track the patch fix accordingly.

@ISEEKYAN
Contributor Author

It is supposed to be fixed in 0.13.

@ccclyu
Collaborator

ccclyu left a comment

LGTM!!

@vermouth1992 vermouth1992 merged commit be47ac4 into volcengine:main May 28, 2025
35 of 36 checks passed
vermouth1992 pushed a commit that referenced this pull request Jun 5, 2025
### What does this PR do?

- support training with deepseekv3 671B
- support MTP on top of #1284

Now it is functionally ready for 671B, though still lacking practical validation.
yellowbee686 pushed a commit to yellowbee686/verl that referenced this pull request Jun 6, 2025
wwwjn pushed a commit to wwwjn/verl that referenced this pull request Jun 10, 2025