Description
Motivation
Terms:
- Roles: actor, critic, reward, etc
- Training Engines: fsdp1, fsdp2, megatron, etc
Currently, the training engines (Megatron/FSDP) are coupled with the RL roles (e.g., actor/critic). This makes it hard to write unit tests for each module. Ideally, for the FSDP/Megatron engines, we want standalone development tests that simply run forward/loss/backward/update. To achieve this separation of concerns, the engines should be standalone modules separated from ActorRolloutRefWorker. That way it becomes much easier for the community to integrate different training backends, such as internal Megatron forks, torchtitan, or other parallel training engines.
On the other hand, the codebase repeats the code that wraps modules with FSDP or Megatron for each role (e.g., reward, critic). To keep things DRY, we need to modularize distributed model creation and put the FSDP/Megatron model forward/backward code into a single class.
Proposed Changes
Refactor the actors (e.g., actor.py) to be agnostic to the training engine backend, while the training engines expose the same interface to the actors.
- worker
  - engine: the model computation layer, covering the various parallelism strategies. It contains only basic operations such as forward, backward, and optimizer step. The loss is defined in the outermost driver code and passed down via a set function to the actor/critic worker layer in workers/actors, where it is actually called.
    - megatron
      - model.py: base class for policy/value models. Includes the most common functions that rarely need modification: optimizer_step, optimizer_zero_grad, etc.
    - fsdp
      - model.py: same as above
  - actors: a lightweight forwarding layer over the engine, used to unify the functions callable by the driver layer for both training and inference backends. Be cautious when adding new elements here. It also serves as the server-facing interface layer; this design can still be refined.
    - actor.py: unifies FSDP/Megatron/etc. Contains the sharding/gather manager and handles optimization-related context-manager work for FSDP (e.g., the SP group) and, in the future, for Megatron (e.g., optimizations that require context management, such as supporting different models and parallel strategies under HE).
    - critic.py: same as above
    - ref.py: unifies training backend / xperf / etc.
    - reward.py: same as above
    - rollout.py: needs to be unified with xperf/model.py in terms of the build_model interface
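The loss-injection contract described above can be sketched as a minimal protocol: the driver defines the loss once, injects it via a set function, and the engine only runs forward/backward around it. Everything here is a toy stand-in for illustration; the real engines operate on DataProto and real modules, and the `_forward`/`_backward` hooks below are hypothetical placeholders for backend-specific code.

```python
from typing import Callable, Dict, Optional


class Engine:
    """Minimal engine protocol: forward/backward only, no RL role semantics."""

    def __init__(self):
        self.loss_fn: Optional[Callable] = None
        self.loss_config: Dict = {}

    def set_loss(self, loss_fn: Callable, loss_config: Dict = None):
        # The loss is defined by the driver and injected; the engine never hardcodes it.
        self.loss_fn = loss_fn
        self.loss_config = loss_config or {}

    def forward_backward_step(self, data: Dict, forward_only: bool = False) -> Dict:
        output = self._forward(data)              # backend-specific forward
        metrics = {}
        if not forward_only:
            loss = self.loss_fn(output, data)     # driver-defined loss
            self._backward(loss)                  # backend-specific backward
            metrics["loss"] = loss
        return {"output": output, "metrics": metrics}

    # Hooks an FSDP/Megatron subclass would override.
    def _forward(self, data):
        return {"logprobs": sum(data["tokens"])}  # toy computation

    def _backward(self, loss):
        pass


# Driver side: define the loss once, reuse it across engine backends.
engine = Engine()
engine.set_loss(lambda out, batch: out["logprobs"] * 0.5)
result = engine.forward_backward_step({"tokens": [1, 2, 3]})
```

Because the actor layer only forwards calls like these, swapping FSDP for Megatron (or an internal fork) becomes a matter of providing another `Engine` subclass.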
Engine Interfaces
FSDP
# `FSDPModel` Class Summary
## `__init__`
- Initializes configuration.
- Compiles entropy/loss functions with `torch.compile`.
- Sets default loss function and config.
- Builds FSDP mesh for distributed training.
---
## `init_model`
Builds and initializes:
1. Hugging Face model config.
2. Model module with rmpad and gradient checkpointing.
3. Optimizer and LR scheduler.
4. Wraps model in FSDP.
---
## `_build_mesh`
- Builds device mesh
---
## `_build_model`
- Validates config (e.g., SP size vs attention heads).
- Applies monkey patch (rmpad).
- Loads HF model (vision or causal LM).
- Applies:
- Gradient checkpointing
- Metrics logging context
- Applies parallel plan (e.g., tensor parallelism).
- Builds FSDP with:
- Mixed precision
- Auto wrap policy
- Sharding strategy (FULL/HYBRID)
- Optional CPU & activation offload
- Registers hooks for activation offload.
---
## `_build_optimizer`
- Creates optimizer via config.
- Sets up warmup-based LR scheduler.
- Logs memory usage post-initialization.
---
## `optimizer_step`
- Optionally loads optimizer to GPU (if offloaded).
- Clips gradients (`clip_grad_norm_`).
- Steps optimizer.
- Optionally offloads optimizer back.
- **Returns:** `{"grad_norm": grad_norm}`
---
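The clip-then-step order above can be sketched in a few lines. This is a hedged sketch, not the actual implementation: plain `torch.nn.utils.clip_grad_norm_` stands in for FSDP's own `clip_grad_norm_`, and the offload handling is omitted.

```python
import torch


def optimizer_step(model: torch.nn.Module, optimizer: torch.optim.Optimizer,
                   max_grad_norm: float = 1.0) -> dict:
    # Clip gradients first (FSDP would use model.clip_grad_norm_), then step.
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    optimizer.step()
    return {"grad_norm": float(grad_norm)}


# Usage on a toy model:
model = torch.nn.Linear(2, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
loss = model(torch.ones(1, 2)).sum()
loss.backward()
metrics = optimizer_step(model, opt)
```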
## `optimizer_zero_grad`
- Properly clears gradients under `use_orig_params=True`.
- Ensures `FlatParam.grad` is cleared manually.
---
## `lr_scheduler_step`
- Steps the LR scheduler.
- **Returns:** `{"lr": current_lr}`
---
## `set_loss(loss_fn, loss_config)`
- Sets external loss function and config (typically by driver).
---
## `_forward_micro_batch`
- Preprocesses micro-batch (unpad, SP slicing).
- Executes forward pass:
- Uses CE fusion if configured and entropy not needed.
- Optionally clamps logits.
- Computes log-probs, entropy.
- Restores padded output to batch shape.
- **Returns:** `TensorDict({"logprobs", "entropy"}), seqlen_rmpad`
---
## `_backward_step(loss)`
- Backward pass: `loss.backward()`
---
## `_make_micro_batches(data)`
- Splits minibatch into micro-batches:
- Based on `micro_batch_tokens` or `micro_batch_size`
- **Returns:** `micro_batches, num_micro_batches, batch_indices`
---
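A token-budget split like `_make_micro_batches` can be sketched without any framework: greedily pack samples until adding the next one would exceed `micro_batch_tokens`. The greedy packing policy and the index-based return format are illustrative assumptions, not the actual implementation.

```python
def make_micro_batches(seq_lens, micro_batch_tokens):
    """Pack sample indices so each micro-batch stays within the token budget.

    Returns (micro_batches, num_micro_batches, batch_indices), mirroring the
    summary above; each micro-batch is a list of sample indices.
    """
    micro_batches, current, current_tokens = [], [], 0
    for idx, n_tokens in enumerate(seq_lens):
        # Start a new micro-batch when this sample would overflow the budget.
        if current and current_tokens + n_tokens > micro_batch_tokens:
            micro_batches.append(current)
            current, current_tokens = [], 0
        current.append(idx)
        current_tokens += n_tokens
    if current:
        micro_batches.append(current)
    # Flattened order, useful for restoring outputs to the original batch layout.
    batch_indices = [i for mb in micro_batches for i in mb]
    return micro_batches, len(micro_batches), batch_indices


mbs, n, order = make_micro_batches([300, 500, 400, 200], micro_batch_tokens=800)
```

Splitting by `micro_batch_size` is the degenerate case where every sample counts as one "token" against a budget of `micro_batch_size`.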
## `forward_backward_step(data, forward_only=False)`
- Validates input keys/meta.
- Loops over micro-batches:
1. Forward pass
2. Compute loss
3. Backward pass
- Accumulates and returns metrics.
- **Returns:** `DataProto(output, meta_info={"metrics": ...})`
---
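The micro-batch loop above can be illustrated end to end on a scalar model with a hand-derived gradient. Everything here (the model `y = weight * x`, the MSE loss, the metric names) is a toy stand-in; the point is the shape of the loop: forward, loss, backward per micro-batch, with gradients scaled by `1/num_micro_batches` for accumulation.

```python
def forward_backward_step(weight, micro_batches, forward_only=False):
    """Loop over micro-batches: forward, loss, backward; accumulate metrics.

    Each micro-batch is a list of (x, y) pairs for the toy model y = weight * x.
    """
    grad, metrics = 0.0, {"loss": []}
    n = len(micro_batches)
    for batch in micro_batches:
        # 1. Forward pass
        preds = [weight * x for x, _ in batch]
        # 2. Compute loss (MSE over the micro-batch)
        loss = sum((p - y) ** 2 for p, (_, y) in zip(preds, batch)) / len(batch)
        metrics["loss"].append(loss)
        # 3. Backward pass: accumulate d(loss)/d(weight), scaled by 1/n
        if not forward_only:
            grad += sum(2 * (p - y) * x for p, (x, y) in zip(preds, batch)) / (len(batch) * n)
    return grad, metrics


grad, metrics = forward_backward_step(1.0, [[(1.0, 2.0)], [(2.0, 2.0)]])
```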
## `get_state_dict()`
- Gets full model state dict using FSDP context.
- Notes memory inefficiency and potential FSDP2 optimization.
Megatron, following the same interface:
```python
class MegatronModel:
    def __init__(self, role_config, model_config, is_actor=True):
        ...

    def _megatron_model_provider(self, pre_process=True, post_process=True):
        """Build the megatron model."""
        return GPTModel()

    def init_model(self):
        get_model(self._megatron_model_provider, True, **self.kwargs)

    def configure_optimizers(self, optimizer_kwargs: DictConfig = None) -> None:
        ...

    def optimizer_step(self, optimizers: List[Optimizer]) -> dict:
        return metrics

    def optimizer_zero_grad(self, optimizers: List[Optimizer]) -> None:
        ...

    def lr_scheduler_step(self, schedulers, **kwargs) -> None:
        if self.megatron_update_successful:
            scheduler.step(increment=1)

    def set_loss_fn(self, loss_fn: Callable):
        """
        Loss function interface: loss_fn(output, micro_batch).
        `output` keys: logprob and entropy.
        """
        self.loss_fn = loss_fn

    def forward_backward_step(self, data: DataProto, forward_only=False) -> DataProto:
        """
        Run forward/backward once over a minibatch.
        data: DataProto containing everything forward_backward_step needs.
        """
        metrics_ret = forward_backward_func(rmpad_micro_batches,
                                            forward_func=forward_func)
        output = DataProto(TensorDict(output), meta_info={"metrics": metrics_ret})
        return output
```
Execution plan:
The following plan involves gradual changes, which is less risky in terms of causing regressions or conflicts.
- Apply the refactored interface to SFT for verification:
  - trainer
    - finetune
      - sft_trainer (trainer impl)
      - fsdp_sft_trainer.py (add a deprecation warning; remove in the next version)
      - main_sft.py (entry script)
  - workers
    - engine
      - fsdp
        - model.py  # FSDPModel
  and update all SFT example scripts to use main_sft.py.
- Add actor/critic/reward with the FSDP engine; deprecate dp_actor/dp_critic/etc.
- Refactor the Megatron code and add a Megatron engine to megatron_actor/critic/reward.
In theory, once the first step is done, we can enable parallel efforts such as integrating FSDP2 and other parallelism into the engine code, or enabling the engines on non-NVIDIA GPUs.
Feedback Period
5/2 - 5/9
CC List
Credit to ZR, @vermouth1992 @Frag17 and the verl team
cc @ETOgaosion @ccclyu @wconstab @lxg2015 @mori360 @weifengpy @PeterSH6 @yushengsu-thu @Chendong98 @as12138