
Conversation

ZihengJiang
Collaborator

@ZihengJiang ZihengJiang commented Jun 11, 2025

[Refactor] Training Engine Interface and Development Plan

Motivation

See the original RFC for background: #1371

Modernizing our training loop requires that we:

  • Decouple training-backend implementation from algorithm code so each can evolve independently
  • Unify on a single, well-defined Engine interface across FSDP/Megatron/etc backends
  • Enable unit-testing of each backend implementation in isolation
  • Guarantee algorithm “roles” (Critic, Actor, Rollout, Ref) remain completely engine-agnostic.

Current Implementation

This PR:

  • Introduces an abstract BaseEngine class that defines a unified training‐engine interface.
  • Implements FSDPEngine, a concrete BaseEngine using PyTorch FullyShardedDataParallel.
  • Provides a CriticWorker based on FSDPEngine that plugs seamlessly into existing PPO training code without any changes.

Classic Training Loop with the New Interface

# 1. Build and initialize engine
engine = FSDPEngine(config)
engine.init_model()
engine.set_loss_fn(loss_fn)

# 2. Training loop
ctx = None  # context dict threaded through the pre/postprocess hooks
for epoch in range(config.num_epochs):
    for batch in train_loader:
        # a) zero gradients
        engine.optimizer_zero_grad()

        # b) forward + backward
        with engine.train_mode():
            preds, loss, ctx = engine.forward_backward_step(
                batch,
                ctx,
                forward_only=False,
                preprocess_fn=preprocess_fn,
                postprocess_fn=postprocess_fn
            )

        # c) update and schedule
        grad_norm = engine.optimizer_step()
        current_lr = engine.lr_scheduler_step()

# 3. Evaluation
ctx = None
with engine.eval_mode():
    for micro_batch in val_loader:
        preds, ctx = engine.forward_backward_step(
            micro_batch,
            ctx,
            forward_only=True,
            preprocess_fn=preprocess_fn,
            postprocess_fn=postprocess_fn
        )
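
The hooks and loss function used above are not defined in this snippet; their expected signatures are documented in the BaseEngine interface below. Purely for illustration, a toy set of implementations (assuming a critic-style value regression and dict-shaped batches with "returns"/"values" keys, which are not prescribed by this PR) might look like:

import torch
import torch.nn.functional as F

def preprocess_fn(batch, ctx):
    # move tensors onto the current device before the model call
    inputs = {k: v.cuda() for k, v in batch.items() if torch.is_tensor(v)}
    return inputs, ctx

def postprocess_fn(outputs, ctx):
    # reduce the raw model output to the predictions we care about
    preds = outputs["values"] if isinstance(outputs, dict) else outputs
    return preds, ctx

def loss_fn(data, predictions, ctx):
    # toy value-regression loss; ctx can carry metrics across micro-batches
    loss = F.mse_loss(predictions, data["returns"])
    return loss, ctx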

Detailed BaseEngine Interface

We now introduce an abstract base class, BaseEngine, which defines our unified training-engine interface.

Key enhancements over the original RFC:

  • train_mode() / eval_mode()
    Context managers to control parameter and activation load/offload at the start and end of each loop.
  • shard_data() / unshard_data()
    APIs for partitioning and gathering data across devices or workers.
  • preprocess_fn / postprocess_fn in forward_backward_step()
    Hooks to apply custom transformations before and after each micro-batch pass.

Below are the detailed signatures for each core method.

class BaseEngine(object):
    """
    Abstract base class defining the interface for model training engines.

    Engine implementations must subclass BaseEngine and provide concrete behavior for all methods.
    """
    def __init__(self, config):
        """
        Initialize the BaseEngine.

        Args:
            config: Configuration object containing parameters for engine setup.
        """
        raise NotImplementedError

    def init_model(self):
        """
        Instantiate or load the model, optimizer, and learning rate scheduler.

        Should prepare all components necessary for training or evaluation.
        """
        raise NotImplementedError

    def train_mode(self):
        """
        Context manager entry for switching the engine and model into training mode.

        Usage:
            with engine.train_mode():
                # runs in training mode
        """
        raise NotImplementedError

    def eval_mode(self):        
        """
        Context manager entry for switching the engine and model into evaluation mode.

        Usage:
            with engine.eval_mode():
                # runs in evaluation mode
        """
        raise NotImplementedError

    def forward_backward_step(self, 
                              batch, 
                              ctx=None, 
                              forward_only=False, 
                              preprocess_fn=None, 
                              postprocess_fn=None):
        """
        Execute a forward pass (and optional backward pass) over a batch of data.

        Args:
            batch: Raw batch data (e.g., tensors or mappings) to process.
            ctx: Optional context dict passed to preprocess/postprocess functions.
            forward_only: If True, skip gradient computation and backward pass.
            preprocess_fn: Function(batch, ctx) -> (inputs, ctx), applied before model call.
            postprocess_fn: Function(outputs, ctx) -> (predictions, ctx), applied after model call.

        Returns:
            If forward_only:
                (predictions, ctx)
            Else:
                (predictions, loss, ctx)
        """
        raise NotImplementedError

    def optimizer_zero_grad(self):
        """
        Zero out gradients of all parameters before starting a new backward pass.
        """
        raise NotImplementedError

    def optimizer_step(self):
        """
        Perform an optimization step to update model parameters based on accumulated gradients.

        Returns:
            grad_norm (float): The norm of the gradients before clipping or update.
        """
        raise NotImplementedError

    def lr_scheduler_step(self):
        """
        Advance the learning rate scheduler by one step.

        Returns:
            current_lr (float or list[float]): Updated learning rate(s).
        """
        raise NotImplementedError

    def shard_data(self, data):
        """
        Shard or partition data for distributed training or parallel execution.

        Args:
            data: Data structure to be sharded across devices/workers.

        Returns:
            Sharded data in the same format as input.
        """
        raise NotImplementedError

    def unshard_data(self, data):
        """
        Reconstruct or gather sharded data back to a unified format.

        Args:
            data: Sharded data structure to reconstruct.

        Returns:
            Unsharded, combined data.
        """
        raise NotImplementedError
        

    def set_loss_fn(self, loss_fn):
        """
        Set the loss function to be used during training.

        Args:
            loss_fn: Callable(data, predictions, ctx) -> (loss_tensor, new_ctx)
        """
        raise NotImplementedError

    def to(self, device: str, model: bool = True, optimizer: bool = True):
        """
        Move model parameters, optimizer states, or both to the specified device.

        Args:
            device: Target device identifier (e.g., "cuda" or "cpu").
            model: If True, move the model.
            optimizer: If True, move the optimizer states.
        """
        raise NotImplementedError


    def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None):
        """
        Save model, optimizer, and scheduler states to a checkpoint.

        Args:
            local_path: Local filesystem path to save checkpoint.
            hdfs_path: Optional HDFS path to copy checkpoint.
            global_step: Integer training step number for naming.
            max_ckpt_to_keep: Maximum number of recent checkpoints to retain.
        """
        raise NotImplementedError


    def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=True):
        """
        Load model, optimizer, and scheduler states from a checkpoint.

        Args:
            local_path: Local filesystem path of the checkpoint.
            hdfs_path: Optional HDFS path where checkpoint is stored.
            del_local_after_load: Whether to delete local copy after loading.
        """
        raise NotImplementedError
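
Since train_mode() / eval_mode() are specified as context managers, one straightforward way a concrete engine could implement them is with contextlib.contextmanager: load parameters on entry, offload on exit. This is only a sketch; self.module, load_model_to_gpu, and offload_model_to_cpu are hypothetical names, not the helpers used in this PR.

from contextlib import contextmanager

class MyEngine(BaseEngine):  # hypothetical subclass, for illustration only
    @contextmanager
    def train_mode(self):
        self.load_model_to_gpu()          # hypothetical offload helper
        self.module.train()
        try:
            yield self
        finally:
            self.module.eval()
            self.offload_model_to_cpu()   # hypothetical offload helper

    @contextmanager
    def eval_mode(self):
        self.load_model_to_gpu()
        self.module.eval()
        try:
            yield self
        finally:
            self.offload_model_to_cpu()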

FSDPEngine Implementation

A concrete FSDPEngine implements all methods using PyTorch FullyShardedDataParallel, supporting all the features that the existing FSDP DPCritic worker supports:

  • Multi-GPU/model sharding
  • Activation- and optimizer-offload
  • LoRA & sequence parallelism
  • Dynamic batch size and padding removal
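
For orientation, a heavily simplified sketch of what init_model could look like in an FSDP-backed engine is shown below (FSDP1-style wrapping only). It is not the implementation in this PR: build_model is a hypothetical helper, and the wrapping policy, mixed precision, and offload settings that the real engine configures are omitted.

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def init_model(self):
    # build the underlying model (hypothetical helper; the real engine builds it from config)
    model = build_model(self.config)
    # shard parameters, gradients, and optimizer state across data-parallel ranks
    self.module = FSDP(model, device_id=torch.cuda.current_device())
    self.optimizer = torch.optim.AdamW(self.module.parameters(), lr=self.config.lr)
    self.lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
        self.optimizer, lr_lambda=lambda step: 1.0  # constant LR placeholder
    )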

CriticWorker Implementation based on the FSDPEngine

  • Unchanged public API
  • Each role calls only BaseEngine methods (init_model, train_mode/eval_mode, forward_backward_step, etc.)
  • No modifications needed in existing algorithms (e.g., PPOTraining)
  • New roles can be plugged in identically to legacy code
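
As a rough illustration of what "calls only BaseEngine methods" means in practice, a compressed, hypothetical critic-style worker could look like the sketch below. The class and method names, value_loss_fn, and the metric keys are placeholders; the real CriticWorker additionally handles DataProto conversion, dispatch registration, and config plumbing.

class EngineCriticWorker:  # hypothetical, for illustration only
    def __init__(self, config, engine):
        self.config = config
        self.engine = engine                      # any BaseEngine implementation
        self.engine.init_model()
        self.engine.set_loss_fn(value_loss_fn)    # placeholder critic loss

    def compute_values(self, data):
        data = self.engine.shard_data(data)
        with self.engine.eval_mode():
            values, _ = self.engine.forward_backward_step(data, forward_only=True)
        return self.engine.unshard_data(values)

    def update_critic(self, data):
        data = self.engine.shard_data(data)
        with self.engine.train_mode():
            self.engine.optimizer_zero_grad()
            _, loss, _ = self.engine.forward_backward_step(data, forward_only=False)
            grad_norm = self.engine.optimizer_step()
            lr = self.engine.lr_scheduler_step()
        return {"critic/loss": loss, "critic/grad_norm": grad_norm, "critic/lr": lr}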

Development Plan

We’ll roll this out in three gated phases, controlled by a feature flag (use_legacy_worker_impl).

Phase 1: Engine Development

Flag: use_legacy_worker_impl = True (default)
New interface under active development

  • Refactor Critic, Actor, Rollout, Ref to use only BaseEngine APIs
  • Design a hierarchical, immutable config system for engine/backends
  • Ensure PPO training curves and final accuracy match legacy implementation

Phase 2: Migration

Flag: use_legacy_worker_impl = False (default) – legacy path logs a deprecation warning
All new code targets the new interface; 2–3 months of integration/stress testing

  • Enforce new interface for all feature work
  • Gather benchmarks, bug reports, and performance data

Phase 3: Cleanup

After Phase 2 validation:

  • Remove legacy worker code and flags
  • Finalize documentation, update changelogs, close deprecation notices

Please review this refactor and share any feedback or concerns! Contributions are welcome.

@CLAassistant

CLAassistant commented Jun 11, 2025

CLA assistant check
All committers have signed the CLA.

@vermouth1992
Collaborator

vermouth1992 commented Jun 12, 2025

@ccclyu Please give some feedback, as the same process will be needed for MegatronWorker. The current design philosophy is:

  • Megatron and FSDP will use the same worker, with the only difference being the ModelEngine. (There are some issues to solve here, because Megatron and FSDP must use different data dispatch modes.) This means we have to choose the data dispatch mode based on the computation backend of CriticWorker.
  • Keep both the legacy CriticWorker and the new CriticWorker for some time, then remove the legacy CriticWorker. There are several issues: 1) CI pressure doubles; 2) how do we maintain consistency when new features are added? 3) how do we handle PRs that contribute to the legacy worker?

@ZihengJiang
Collaborator Author

@vermouth1992 @eric-haibin-lin @PeterSH6 @tongyx361 @ETOgaosion @hongpeng-guo @wwwjn @tianyu-l
Please review this refactor and share any feedback or concerns!

@vermouth1992
Collaborator

vermouth1992 commented Jun 12, 2025

Checklist

  • FSDP self-containment: whether actor/ref/critic/reward model can be easily implemented using this API @vermouth1992 @PeterSH6
  • Whether LoRA can be implemented using this API
  • Whether VLM can be easily implemented using this API @hiyouga
  • Whether Megatron-LM can use the ModelEngine design @ISEEKYAN
  • Whether TorchTitan can use the ModelEngine design
  • Whether MindSpeed can use the ModelEngine design
  • How to handle the data dispatch problem? Currently it's mixed: single-controller dispatch + internal resharding. @vermouth1992

@vadimkantorov

vadimkantorov commented Jun 16, 2025

using PyTorch FullyShardedDataParallel.

Meaning FSDP1?

@ZihengJiang
Collaborator Author

using PyTorch FullyShardedDataParallel.

Meaning FSDP1?

I also migrated the FSDP2 implementation in the PR.

elif config.strategy == "fsdp2":

Contributor

@ISEEKYAN ISEEKYAN left a comment


Hello, I have some comments from the perspective of Megatron compatibility.

Comment on lines 78 to 66
responses = batch["responses"]
attention_mask = batch["attention_mask"]
values = batch["values"]
returns = batch["returns"]
Contributor


do we have a protocol for batch?

@ZihengJiang
Copy link
Collaborator Author

@vermouth1992 @ISEEKYAN made several updates:

  • adapted the forward_backward_step interface to take a mini-batch instead of a micro-batch
  • moved the Ulysses config into the engine implementation and left the micro-batch processing logic to the engine as well

Let me know whether the current interface looks good to you.

@eranhirs

eranhirs commented Jul 7, 2025

Just putting some feedback here, since it seems like the right place: it would be useful to be able to run non-hybrid engines, which currently throw NotImplementedError (see here). One motivation for this is #1049.

@ZihengJiang ZihengJiang force-pushed the ziheng/dev-0610 branch 2 times, most recently from 1f606af to f98df19 Compare July 8, 2025 23:57
@ZihengJiang ZihengJiang changed the title WIP: [Refactor] Training Engine Interface and Development Plan [Refactor] Training Engine Interface and Development Plan Jul 8, 2025
@ZihengJiang ZihengJiang changed the title [Refactor] Training Engine Interface and Development Plan [trainer] refactor: Training Engine Interface and Development Plan Jul 8, 2025
@ZihengJiang ZihengJiang marked this pull request as ready for review July 9, 2025 00:09
output = output.to("cpu")
return output

@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
Collaborator


We will need to decide how we dispatch here.

Collaborator Author


can you elaborate on this?

Collaborator


we'll work on that for megatron backend

Supports model sharding, activation/optimizer offloading, LoRA, and sequence parallelism.
"""

def __init__(self, config):
Collaborator


Shall we define a dataclass for this config?

Collaborator


And convert hydra config to dataclass config at worker level

Collaborator

@eric-haibin-lin eric-haibin-lin left a comment


nice!

@eric-haibin-lin eric-haibin-lin merged commit 9d7cba4 into volcengine:main Jul 18, 2025
61 of 63 checks passed
eric-haibin-lin pushed a commit to eric-haibin-lin/verl that referenced this pull request Jul 19, 2025
yellowbee686 pushed a commit to yellowbee686/verl that referenced this pull request Jul 25, 2025
oseyosey pushed a commit to oseyosey/verl that referenced this pull request Jul 28, 2025
eric-haibin-lin added a commit that referenced this pull request Aug 3, 2025
…est, fix math_dataset path error (#2647)

### What does this PR do?

PR #1977 is a great piece of work. I tried using the new engine, found some
minor problems, and added a CI test for FSDPEngine.
- Use the newest name `gather_outputs_and_unpad` for the function.
- Removed invalid calculations originally used for gradient accumulation
(gradient accumulation has been moved to loss_fn in the new engine).
- Fixed misuses of two variables.

---------

Signed-off-by: ShareLer <ShareLe@163.com>
Co-authored-by: eric-haibin-lin <linhaibin.eric@gmail.com>
Juniper1021 pushed a commit to Juniper1021/verl that referenced this pull request Aug 7, 2025
Juniper1021 pushed a commit to Juniper1021/verl that referenced this pull request Aug 7, 2025
HaeChan0305 pushed a commit to HaeChan0305/MLILAB-GRPO that referenced this pull request Aug 8, 2025
…est, fix math_dataset path error (volcengine#2647)

yellowbee686 pushed a commit to yellowbee686/verl that referenced this pull request Aug 11, 2025
…est, fix math_dataset path error (volcengine#2647)

whatadayG pushed a commit to whatadayG/verl that referenced this pull request Sep 5, 2025
…olcengine#1977)

whatadayG pushed a commit to whatadayG/verl that referenced this pull request Sep 5, 2025
…est, fix math_dataset path error (volcengine#2647)