Conversation

none0663
Contributor

### What does this PR do?

Fix Configuration for Micro Batch Size in Megatron's Ref Policy

### High-Level Design

This pull request addresses an issue with the micro batch size configuration in Megatron's ref policy. The default `ppo_megatron_trainer.yaml` includes only two settings, `log_prob_micro_batch_size` and `log_prob_micro_batch_size_per_gpu`:

```yaml
log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu
log_prob_micro_batch_size_per_gpu: null
```
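As the YAML comment notes, the global `log_prob_micro_batch_size` is deprecated in favor of the per-GPU key. A common migration shim looks like the sketch below; the divide-by-world-size convention is an assumption for illustration, not necessarily verl's exact logic:

```python
def normalize_micro_batch_size(global_size, per_gpu_size, n_gpus):
    """Prefer the per-GPU key; derive it from the deprecated global key if needed.

    Illustrative helper, not part of verl's codebase.
    """
    if per_gpu_size is not None:
        return per_gpu_size
    if global_size is not None:
        # Assumed convention: the global size is split evenly across GPUs.
        assert global_size % n_gpus == 0, "global size must divide evenly across GPUs"
        return global_size // n_gpus
    return None  # neither key set: exactly the misconfiguration this PR guards against

print(normalize_micro_batch_size(16, None, 4))  # -> 4
print(normalize_micro_batch_size(None, 2, 4))   # -> 2
```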

However, `megatron_workers.py` reads `ref.log_prob_micro_batch_size_per_gpu`:

```python
micro_batch_size = self.config.ref.log_prob_micro_batch_size_per_gpu
data.meta_info["micro_batch_size"] = micro_batch_size
```

while `megatron_actor.py` falls back to `ppo_micro_batch_size_per_gpu`:

```python
if data.meta_info.get("micro_batch_size", None) is not None:
    batch_size = data.meta_info["micro_batch_size"]
else:
    batch_size = self.config.ppo_micro_batch_size_per_gpu
```

Neither of these is directly related to `ppo_micro_batch_size`.
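To make the mismatch concrete, here is a minimal, self-contained sketch (illustrative stand-ins, not verl's actual classes) of how the per-GPU micro batch size is handed from the worker to the actor via `meta_info`, with the actor falling back to its own `ppo_micro_batch_size_per_gpu` when the handoff yields `None`:

```python
# Sketch of the worker -> actor micro-batch-size handoff.
# RefWorker/RefActor are hypothetical stand-ins for the verl classes.

class RefWorker:
    def __init__(self, log_prob_micro_batch_size_per_gpu):
        self.log_prob_micro_batch_size_per_gpu = log_prob_micro_batch_size_per_gpu

    def build_meta_info(self):
        # Mirrors megatron_workers.py: reads ref.log_prob_micro_batch_size_per_gpu
        return {"micro_batch_size": self.log_prob_micro_batch_size_per_gpu}

class RefActor:
    def __init__(self, ppo_micro_batch_size_per_gpu=None):
        self.ppo_micro_batch_size_per_gpu = ppo_micro_batch_size_per_gpu

    def resolve_batch_size(self, meta_info):
        # Mirrors megatron_actor.py: fall back to ppo_micro_batch_size_per_gpu
        if meta_info.get("micro_batch_size", None) is not None:
            return meta_info["micro_batch_size"]
        return self.ppo_micro_batch_size_per_gpu

worker = RefWorker(log_prob_micro_batch_size_per_gpu=4)
actor = RefActor()  # ppo_micro_batch_size_per_gpu never set by the default yaml
print(actor.resolve_batch_size(worker.build_meta_info()))   # -> 4
print(actor.resolve_batch_size({"micro_batch_size": None})) # -> None (silent misconfiguration)
```

When the yaml leaves both keys `null`, the fallback silently resolves to `None` instead of failing fast, which is the gap the PR closes.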

To resolve this, I have modified the configuration calculations and added `raise ValueError` statements to ensure the necessary parameters are correctly defined.

This update ensures the required parameters are properly handled, preventing runtime errors and improving the robustness of the training process.

### Changes Made:

  • Modified the configuration calculations in `megatron_workers.py`.

  • Added `raise ValueError` statements to check for the presence of `log_prob_micro_batch_size_per_gpu` and `ppo_micro_batch_size_per_gpu`.

The proposed fallback in `megatron_workers.py` (the snippet the review below refers to):

```python
else:
    if self.config.ref.get("log_prob_micro_batch_size_per_gpu", None):
        self.config.ref.ppo_micro_batch_size_per_gpu = self.config.ref.log_prob_micro_batch_size_per_gpu
    elif self.config.ref.get("ppo_micro_batch_size_per_gpu", None):
```
Collaborator

Thanks for the contribution!

I think there is a typo here, so we may not need to consider `ppo_micro_batch_size_per_gpu`; you can simply check the key above~

Contributor Author


Fixed; deleted the `ppo_micro_batch_size_per_gpu` branch.
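After the review, the fallback reduces to checking only `log_prob_micro_batch_size_per_gpu`. A hedged sketch of the resulting guard follows; the key names match the PR, but the plain-dict config and the function itself are illustrative assumptions, not verl's actual code:

```python
# Sketch of the simplified validation after dropping the
# ppo_micro_batch_size_per_gpu branch; dict-based config is illustrative.

def resolve_ref_micro_batch_size(ref_config: dict) -> int:
    """Return the per-GPU log-prob micro batch size, or raise if undefined."""
    per_gpu = ref_config.get("log_prob_micro_batch_size_per_gpu", None)
    if per_gpu is None:
        # Mirrors the PR's added `raise ValueError` guard: fail fast
        # instead of letting the batch size silently resolve to None.
        raise ValueError(
            "ref.log_prob_micro_batch_size_per_gpu must be set "
            "(log_prob_micro_batch_size is deprecated)."
        )
    return per_gpu

print(resolve_ref_micro_batch_size({"log_prob_micro_batch_size_per_gpu": 2}))  # -> 2
```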

@ETOgaosion ETOgaosion merged commit 99e749a into volcengine:main May 28, 2025
20 checks passed
ETOgaosion pushed a commit to Jianbing-D/verl that referenced this pull request Jun 8, 2025
…cengine#1700)

wwwjn pushed a commit to wwwjn/verl that referenced this pull request Jun 10, 2025
…cengine#1700)
