[feat][BREAKING] Megatron support dynamic batch size, to rebalance the workloads #1617
Conversation
High-level comment: the same logic appears in actor, critic, and reward, and it might be better to redesign it into a standalone module. Not urgent for this PR.
verl/utils/seqlen_balancing.py (outdated)

```py
    return a + (a % b)


def rearrange_micro_batches(batch, max_token_len, dp_group=None, num_batches_devided_by=None, same_micro_num_in_dp=True, min_num_micro_batch=None):
```
Should this be a typo? `num_batches_devided_by` -> `num_batches_divided_by`. It appears in multiple places.
Thanks for reviewing~ Maybe I already fixed it in a previous commit?
No, it should be `divided`, not `devided`. The previous commit did not fix this; there are multiple places still using `devided`.
verl/utils/seqlen_balancing.py (outdated)

```py
    """
    Split a batch into micro-batches by total token count, with optional DP sync and padding.

    Args:
        batch (TensorDict): must include "attention_mask" (B*S); other fields are sliced similarly.
        max_token_len (int): max sum of attention_mask per micro-batch.
        dp_group (optional): torch.distributed group for data-parallel sync.
        vpp_size (optional): virtual pipeline parallel size, for megatron.
```
Is `vpp_size` in the `rearrange_micro_batches` definition?
I used to pass `vpp_size`, which was an error; I use `num_batches_devided_by` instead~
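For illustration, here is a minimal, hypothetical sketch of how such a divisor could be applied (the function name and the greedy packing are assumptions, not the PR's actual implementation): plan micro-batches by token budget, then round the count up to a multiple of `num_batches_devided_by` (e.g. `microbatch_group_size_per_vp_stage`) so Megatron's VPP grouping constraint holds.

```py
def plan_num_micro_batches(seq_lens, max_token_len, num_batches_devided_by=None):
    """Greedy token-budget packing, then round the micro-batch count up to a
    multiple of num_batches_devided_by (e.g. microbatch_group_size_per_vp_stage)."""
    num, current = 1, 0
    for n in seq_lens:
        if current + n > max_token_len:
            num += 1       # start a new micro-batch when the budget would overflow
            current = n
        else:
            current += n
    if num_batches_devided_by is not None:
        num = ((num + num_batches_devided_by - 1) // num_batches_devided_by) * num_batches_devided_by
    return num

# 4 micro-batches fit the 1024-token budget; rounded up to 6 for a divisor of 3.
print(plan_num_micro_batches([900, 800, 700, 600], max_token_len=1024, num_batches_devided_by=3))  # 6
```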
verl/utils/seqlen_balancing.py (outdated)

```diff
@@ -212,14 +212,19 @@ def ceildiv(a, b):
     return -(a // -b)

-def rearrange_micro_batches(batch, max_token_len, dp_group=None, same_micro_num_in_dp=True, min_num_micro_batch=None):
+def roundup_divisible(a, b):
+    return a + (a % b)
```
I think the implementation is not correct. How about `return ((a + b - 1) // b) * b`? You can double-check it.
Yeah, this is correct. I tried to use `%`, which has incorrect logic.

For the repeated content, yes, I also feel there is common logic across these models; we can use abstraction to reduce the repetition.
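To make the difference concrete, here is a quick standalone check (not from the PR) contrasting the two round-up formulas:

```py
# The original formula does not always produce a multiple of b;
# the suggested one always rounds up to the next multiple.
def roundup_wrong(a, b):
    return a + (a % b)              # a=5, b=4 -> 6, not a multiple of 4

def roundup_divisible(a, b):
    return ((a + b - 1) // b) * b   # a=5, b=4 -> 8; a=8, b=4 -> 8

for a in range(1, 20):
    assert roundup_divisible(a, 4) % 4 == 0 and roundup_divisible(a, 4) >= a
```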
[feat][BREAKING] Megatron support dynamic batch size, to rebalance the workloads (volcengine#1617)

### Checklist Before Starting

- [x] Search for similar PR(s).

### What does this PR do?

1. Megatron supports dynamic batch size, to rebalance the workloads.
2. Fix missing critic metrics.

### High-Level Design

Follow the FSDP's dynamic batch size.

### Specific Changes

Use the `rearrange_micro_batches` API, but compatible with Megatron VPP constraints.

```py
vpp_size = mpu.get_virtual_pipeline_model_parallel_world_size()
if vpp_size is not None and vpp_size > 1:
    microbatch_group_size_per_vp_stage = self.tf_config.microbatch_group_size_per_vp_stage
    micro_batches, indices = rearrange_micro_batches(batch=mini_batch.batch, num_batches_devided_by=microbatch_group_size_per_vp_stage, max_token_len=max_token_len)
    assert len(micro_batches) % self.tf_config.microbatch_group_size_per_vp_stage == 0, f"micro_batches {micro_batches} must be divisible by microbatch_group_size_per_vp_stage {microbatch_group_size_per_vp_stage} for megatron backend"
else:
    micro_batches, indices = rearrange_micro_batches(batch=mini_batch.batch, max_token_len=max_token_len)
```

@vermouth1992 please check whether it makes sense.

Megatron's constraint when using the interleaving pipeline:

```py
# If the final micro-batch group has fewer micro-batches than pipeline-parallel size,
# the pipeline will have dependency bubbles.
final_microbatch_group_size = num_microbatches % config.microbatch_group_size_per_vp_stage
if 0 < final_microbatch_group_size < pipeline_parallel_size:
    msg = 'The remainder of M (the total micro-batches) divided by N (number of '
    msg += 'contiguous micro-batches in a virtual pipeline stage) should be 0, '
    msg += 'or larger than or equal to the pipeline-parallel size, but it is '
    msg += f'{final_microbatch_group_size}. '
    msg += 'Otherwise, it introduces dependency bubbles in the pipeline '
    msg += 'and reduces throughput.'
    raise RuntimeError(msg)
```

### API

Megatron `forward_backward_batch` has a changed input, and the output has become a dict containing the original `output` and the `indices` needed for `compute_old_log_probs`.

### Usage Example

```bash
actor_rollout_ref.actor.use_dynamic_bsz=${USE_DYNAMIC_BSZ} \
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${ppo_max_token_len_per_gpu} \
critic.ppo_max_token_len_per_gpu=${forward_max_token_len_per_gpu} \
```

Other models will directly copy the config.

### Test

> For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.

### Additional Info.

- **Issue Number**: Fixes issue # or discussion # if any.
- **Training**: [Note which backend this PR will affect: FSDP, Megatron, both, or none]
- **Inference**: [Note which backend this PR will affect: vLLM, SGLang, both, or none]

### Checklist Before Submitting

- [x] Read the [Contribute Guide](https://github.com/volcengine/verl?tab=readme-ov-file#contribution-guide).
- [x] Apply [pre-commit checks](https://github.com/volcengine/verl?tab=readme-ov-file#code-linting-and-formatting).
- [x] Add `[BREAKING]` to the PR title if it breaks any API.
- [ ] Update the documentation about your changes in the [docs](https://github.com/volcengine/verl/tree/main/docs).
- [x] Add CI test(s) if necessary.
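As a hypothetical illustration of the new output contract described in the API section (the function and variable names below are assumptions, not the PR's exact code), a caller could concatenate the per-micro-batch outputs and use the returned `indices` to restore the original sample order before computing old log-probs, similar to the existing FSDP dynamic-batch-size path:

```py
import itertools
import torch

def restore_original_order(micro_batch_outputs, micro_batch_indices):
    """Concatenate per-micro-batch outputs and revert them to the original sample order,
    given the index lists used by rearrange_micro_batches to build the micro-batches."""
    flat = torch.cat(micro_batch_outputs, dim=0)
    indices = list(itertools.chain.from_iterable(micro_batch_indices))
    revert = torch.empty(len(indices), dtype=torch.long)
    revert[torch.tensor(indices)] = torch.arange(len(indices))  # inverse permutation
    return flat[revert]

# Toy usage: two micro-batches holding original samples [2, 0] and [1, 3].
outs = [torch.tensor([20.0, 0.0]), torch.tensor([10.0, 30.0])]
print(restore_original_order(outs, [[2, 0], [1, 3]]))  # tensor([ 0., 10., 20., 30.])
```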