
[RFC] engine interface for training backends (FSDP, FSDP2, torchtitan, Megatron, Mindspore, PAI-Megatron, etc) #1371

@eric-haibin-lin

Description


Motivation

Terms:

  • Roles: actor, critic, reward, etc
  • Training Engines: fsdp1, fsdp2, megatron, etc

Currently, the training engines (Megatron/FSDP) are coupled with the RL roles (e.g. actor/critic). This makes it hard to write unit tests for each module. Ideally, for the FSDP/Megatron engines we want standalone tests that simply run forward/loss/backward/update during development. To achieve this separation of concerns, the engines should be standalone modules decoupled from ActorRolloutRefWorker. This would make it much easier for the community to integrate different training backends, such as internal Megatron forks, torchtitan, or other parallel training engines.

On the other hand, the codebase contains repeated code that wraps modules with FSDP or Megatron for each role (e.g. reward, critic). To keep the code DRY, we need to modularize distributed model creation and put the FSDP/Megatron model forward/backward code into a single class.

Proposed Changes

Refactor the actors (e.g. actor.py) to be agnostic to the training engine backend, while all training engines expose the same interface to the actors.

worker
    engine (Model computation layer, includes various parallelism strategies. Only contains basic operations like forward, backward, and optimizer step. The definition of the loss is in the outermost driver code, passed down via a set function to the actor/critic worker layer in workers/actors, where it’s actually called.)
        megatron
            model.py (Base class for policy/value models. Includes most common functions that rarely need modification: optimizer_step, optimizer_zero_grad, etc.)
        fsdp
            model.py (Same as above)
    actors (A lightweight forwarding layer of the engine, used to unify the functions callable by the driver layer for both training and inference backends. Be cautious when adding new elements. It also serves as the server-facing interface layer — this design can still be refined.)
        actor.py (Unifies FSDP/Megatron/etc. Contains sharding/gather manager, handles some optimization-related context manager work for FSDP like sp group, and in the future, for Megatron — e.g., optimizations requiring context management such as supporting different models and parallel strategies under HE.)
        critic.py (Same as above)
        ref.py (Unifies training backend / xperf / etc.)
        reward.py (Same as above)
        rollout.py (Needs to be unified with xperf/model.py in terms of the build_model interface)
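The common interface the layout above implies could be sketched as an abstract base class. This is a minimal illustration, not the final API; the class and method names (`TrainingEngine`, `set_loss`, etc.) follow the method list below but are otherwise assumptions:

```python
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict


class TrainingEngine(ABC):
    """Backend-agnostic engine interface (illustrative names, not final)."""

    @abstractmethod
    def init_model(self) -> None:
        """Build the model, optimizer, and LR scheduler under the chosen parallelism."""

    @abstractmethod
    def set_loss(self, loss_fn: Callable, loss_config: Any = None) -> None:
        """The loss is defined in the driver and injected into the engine."""

    @abstractmethod
    def forward_backward_step(self, data: Any, forward_only: bool = False) -> Dict:
        """Forward (and optionally backward) over one minibatch; returns metrics."""

    @abstractmethod
    def optimizer_step(self) -> Dict:
        """Clip gradients, step the optimizer, and return metrics such as grad_norm."""
```

Any backend (FSDP, Megatron, torchtitan, internal forks) would subclass this, so the actor/critic workers never see backend-specific code.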

Engine Interfaces

FSDP

# `FSDPModel` Class Summary

## `__init__`
- Initializes configuration.
- Compiles entropy/loss functions with `torch.compile`.
- Sets default loss function and config.
- Builds FSDP mesh for distributed training.

---

## `init_model`
Builds and initializes:
1. Hugging Face model config.
2. Model module with rmpad and gradient checkpointing.
3. Optimizer and LR scheduler.
4. Wraps model in FSDP.

---

## `_build_mesh`
- Builds device mesh

---

## `_build_model`
- Validates config (e.g., SP size vs attention heads).
- Applies monkey patch (rmpad).
- Loads HF model (vision or causal LM).
- Applies:
  - Gradient checkpointing
  - Metrics logging context
- Applies parallel plan (e.g., tensor parallelism).
- Builds FSDP with:
  - Mixed precision
  - Auto wrap policy
  - Sharding strategy (FULL/HYBRID)
  - Optional CPU & activation offload
- Registers hooks for activation offload.

---

## `_build_optimizer`
- Creates optimizer via config.
- Sets up warmup-based LR scheduler.
- Logs memory usage post-initialization.

---

## `optimizer_step`
- Optionally loads optimizer to GPU (if offloaded).
- Clips gradients (`clip_grad_norm_`).
- Steps optimizer.
- Optionally offloads optimizer back.
- **Returns:** `{"grad_norm": grad_norm}`
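The clip-then-step pattern above can be sketched in plain PyTorch. This is a single-device simplification: under FSDP the clipping call would be the model's own `clip_grad_norm_`, and the optimizer offload/reload steps are omitted:

```python
import torch


def optimizer_step(model: torch.nn.Module,
                   optimizer: torch.optim.Optimizer,
                   max_grad_norm: float = 1.0) -> dict:
    """Clip gradients, step the optimizer, and report the pre-clip grad norm.

    Sketch only: FSDP would use model.clip_grad_norm_(max_grad_norm), and
    optional CPU offload of optimizer state is not shown here.
    """
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    optimizer.step()
    return {"grad_norm": float(grad_norm)}
```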

---

## `optimizer_zero_grad`
- Properly clears gradients under `use_orig_params=True`.
- Ensures `FlatParam.grad` is cleared manually.

---

## `lr_scheduler_step`
- Steps the LR scheduler.
- **Returns:** `{"lr": current_lr}`

---

## `set_loss(loss_fn, loss_config)`
- Sets external loss function and config (typically by driver).

---

## `_forward_micro_batch`
- Preprocesses micro-batch (unpad, SP slicing).
- Executes forward pass:
  - Uses CE fusion if configured and entropy not needed.
  - Optionally clamps logits.
- Computes log-probs, entropy.
- Restores padded output to batch shape.
- **Returns:** `TensorDict` with keys `logprobs` and `entropy`, plus `seqlen_rmpad`

---

## `_backward_step(loss)`
- Backward pass: `loss.backward()`

---

## `_make_micro_batches(data)`
- Splits minibatch into micro-batches:
  - Based on `micro_batch_tokens` or `micro_batch_size`
- **Returns:** `micro_batches, num_micro_batches, batch_indices`
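The two splitting modes can be sketched as follows. The function and argument names mirror the summary above but the implementation details (e.g. treating samples as token-id sequences) are assumptions:

```python
def make_micro_batches(samples, micro_batch_size=None, micro_batch_tokens=None):
    """Split a minibatch into micro-batches by sample count or token budget.

    Sketch only: `samples` is a list of token-id sequences; exactly one of
    micro_batch_size / micro_batch_tokens is expected to be set.
    """
    if micro_batch_size is not None:
        # Fixed number of samples per micro-batch.
        micro_batches = [samples[i:i + micro_batch_size]
                         for i in range(0, len(samples), micro_batch_size)]
    else:
        # Greedy packing under a token budget.
        micro_batches, current, budget = [], [], 0
        for seq in samples:
            if current and budget + len(seq) > micro_batch_tokens:
                micro_batches.append(current)
                current, budget = [], 0
            current.append(seq)
            budget += len(seq)
        if current:
            micro_batches.append(current)
    batch_indices = list(range(len(micro_batches)))
    return micro_batches, len(micro_batches), batch_indices
```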

---

## `forward_backward_step(data, forward_only=False)`
- Validates input keys/meta.
- Loops over micro-batches:
  1. Forward pass
  2. Compute loss
  3. Backward pass
- Accumulates and returns metrics.
- **Returns:** `DataProto(output, meta_info={"metrics": ...})`
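The micro-batch loop above can be sketched as below. This is a single-device simplification with the distributed details and `DataProto` wrapping omitted; `forward_fn` and `loss_fn` stand in for the engine's forward pass and the driver-injected loss:

```python
def forward_backward_step(micro_batches, forward_fn, loss_fn, forward_only=False):
    """Forward each micro-batch, compute the injected loss, backward, and
    accumulate metrics. Gradients are scaled by 1/n so they average over
    micro-batches (gradient accumulation)."""
    n = len(micro_batches)
    metrics = []
    for mb in micro_batches:
        output = forward_fn(mb)
        loss = loss_fn(output, mb)
        if not forward_only:
            (loss / n).backward()
        metrics.append({"loss": float(loss.detach())})
    return {"metrics": metrics}
```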

---

## `get_state_dict()`
- Gets full model state dict using FSDP context.
- Notes memory inefficiency and potential FSDP2 optimization.

Megatron, following the same interface:

class MegatronModel:
    def __init__(
        self,
        role_config,
        model_config,
        is_actor=True
    ):
        ...

    def _megatron_model_provider(self, pre_process=True, post_process=True):
        """Build the megatron model."""
        return GPTModel()

    def init_model(self):
        self.module = get_model(self._megatron_model_provider, True, **self.kwargs)
    
    def configure_optimizers(self, optimizer_kwargs: DictConfig = None) -> None:
        ...

    def optimizer_step(self, optimizers: List[Optimizer]) -> dict:
        return metrics

    def optimizer_zero_grad(self, optimizers: List[Optimizer]) -> None:
        ...

    def lr_scheduler_step(self, schedulers, **kwargs) -> None:
        if self.megatron_update_successful:
            for scheduler in schedulers:
                scheduler.step(increment=1)
    
    def set_loss_fn(self, loss_fn: Callable):
        """
        loss function interface: loss_fn(output, micro_batch) 
        output keys: logprob and entropy
        """
        self.loss_fn = loss_fn

    def forward_backward_step(self, data: DataProto, forward_only=False) -> DataProto:
        """
        Run forward/backward once for a minibatch.
        data: DataProto containing everything forward_backward_step needs.
        """
        metrics_ret = forward_backward_func(rmpad_micro_batches,
                                            forward_func=forward_func)
        output = DataProto(TensorDict(output), meta_info={"metrics": metrics_ret})
        return output
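Putting the two backends behind one interface means the driver-side code stays identical. Below is a hypothetical end-to-end sketch with a tiny single-device stand-in engine (`TinyEngine` is invented for illustration; real engines would wrap the model with FSDP or Megatron):

```python
import torch


class TinyEngine:
    """Single-device stand-in for the proposed engine interface."""

    def init_model(self):
        self.model = torch.nn.Linear(4, 1)
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.1)

    def set_loss_fn(self, loss_fn):
        # The driver owns the loss definition; the engine only calls it.
        self.loss_fn = loss_fn

    def forward_backward_step(self, data, forward_only=False):
        output = self.model(data["x"])
        loss = self.loss_fn(output, data)
        if not forward_only:
            loss.backward()
        return {"metrics": {"loss": float(loss.detach())}}

    def optimizer_step(self):
        grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
        self.optimizer.step()
        return {"grad_norm": float(grad_norm)}

    def optimizer_zero_grad(self):
        self.optimizer.zero_grad()


# Driver side: define the loss here, inject it, then call only interface methods.
engine = TinyEngine()
engine.init_model()
engine.set_loss_fn(lambda out, data: ((out - data["y"]) ** 2).mean())
batch = {"x": torch.randn(8, 4), "y": torch.randn(8, 1)}
step_out = engine.forward_backward_step(batch)
opt_metrics = engine.optimizer_step()
engine.optimizer_zero_grad()
```

Swapping `TinyEngine` for an FSDP- or Megatron-backed engine would leave the driver loop unchanged, which is the point of the refactor.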

Execution plan:

The following plan involves gradual changes, reducing the risk of regressions or conflicts.

  1. apply the refactored interface to SFT for verification:
trainer
    finetune
        sft_trainer (trainer impl)
    fsdp_sft_trainer.py (add deprecation warning and remove in next version)
    main_sft.py (entry script)
workers
    engine
        fsdp
            model.py # FSDPModel

and update all SFT example scripts to use main_sft.py.

  2. add actor/critic/reward with the fsdp engine, and deprecate dp_actor/dp_critic/etc.
  3. refactor the megatron code and add a megatron engine to megatron_actor/critic/reward

In theory, once step (1) is done, parallel efforts can proceed, such as integrating fsdp2 and other parallelism strategies into the engine code, or enabling engines for non-NVIDIA GPUs.

Feedback Period

5/2 - 5/9

CC List

Credit to ZR, @vermouth1992 @Frag17 and the verl team
cc @ETOgaosion @ccclyu @wconstab @lxg2015 @mori360 @weifengpy @PeterSH6 @yushengsu-thu @Chendong98 @as12138
