[trainer] refactor: Training Engine Interface and Development Plan #1977
Conversation
@ccclyu Please give some feedback, as the same process will be needed for MegatronWorker. The current design philosophy is:
@vermouth1992 @eric-haibin-lin @PeterSH6 @tongyx361 @ETOgaosion @hongpeng-guo @wwwjn @tianyu-l
Checklist
Meaning FSDP1?
I also migrated the FSDP2 implementation in this PR (verl/workers/engine/fsdp/engine_impl.py, line 248 in a745948).
Hello, I have some comments from the perspective of Megatron compatibility.
verl/workers/roles/critic.py (Outdated)
responses = batch["responses"]
attention_mask = batch["attention_mask"]
values = batch["values"]
returns = batch["returns"]
do we have a protocol for batch?
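For illustration only, one way such a "protocol" could be made explicit is a small validation helper that asserts the keys the critic update expects; `REQUIRED_CRITIC_KEYS` and `validate_critic_batch` below are hypothetical names, not part of this PR:

```python
# Hypothetical sketch of a batch contract for the critic update; not verl's actual API.
from typing import Iterable, Mapping

import torch

REQUIRED_CRITIC_KEYS = ("responses", "attention_mask", "values", "returns")


def validate_critic_batch(batch: Mapping[str, torch.Tensor],
                          required: Iterable[str] = REQUIRED_CRITIC_KEYS) -> None:
    """Fail fast if the batch is missing keys the critic loss relies on."""
    missing = [k for k in required if k not in batch]
    if missing:
        raise KeyError(f"critic batch is missing required keys: {missing}")
```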
@vermouth1992 @ISEEKYAN I made several updates:
Let me know whether the current interface looks good to you.
Force-pushed from 1f606af to f98df19.
output = output.to("cpu")
return output

@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
We will need to decide how we dispatch here.
can you elaborate on this?
We'll work on that for the Megatron backend.
Supports model sharding, activation/optimizer offloading, LoRA, and sequence parallelism.
"""

def __init__(self, config):
Shall we define a dataclass for this config?
And convert the Hydra config to a dataclass config at the worker level.
nice!
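For illustration, converting the Hydra/OmegaConf node into an immutable dataclass at the worker level could look roughly like the sketch below; `FSDPEngineConfig` and its fields are hypothetical, not the actual verl config schema:

```python
# Hypothetical sketch: a typed, frozen config built from the Hydra/OmegaConf node.
from dataclasses import dataclass, fields

from omegaconf import DictConfig, OmegaConf


@dataclass(frozen=True)
class FSDPEngineConfig:  # illustrative fields only
    model_path: str = ""
    micro_batch_size_per_gpu: int = 1
    param_offload: bool = False
    optimizer_offload: bool = False


def to_engine_config(cfg: DictConfig) -> FSDPEngineConfig:
    """Resolve the OmegaConf node and keep only the fields the dataclass declares."""
    resolved = OmegaConf.to_container(cfg, resolve=True)
    known = {f.name for f in fields(FSDPEngineConfig)}
    return FSDPEngineConfig(**{k: v for k, v in resolved.items() if k in known})
```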
Force-pushed from 320b55a to 41e2333.
# [Refactor] Training Engine Interface and Development Plan

## Motivation

See the original RFC for background: volcengine#1371

Modernizing our training loop requires that we:

- **Decouple** the training-backend implementation from algorithm code so each can evolve independently
- **Unify** on a single, well-defined `Engine` interface across FSDP/Megatron/etc. backends
- **Enable** unit testing of each backend implementation in isolation
- **Guarantee** that algorithm "roles" (Critic, Actor, Rollout, Ref) remain completely engine-agnostic

---

## Current Implementation

This PR:

- Introduces an abstract `BaseEngine` class that defines a unified training-engine interface.
- Implements `FSDPEngine`, a concrete `BaseEngine` using PyTorch FullyShardedDataParallel.
- Provides a `CriticWorker` based on `FSDPEngine` that plugs seamlessly into existing PPO training code without any changes.

### Classic Training Loop with the New Interface

```python
# 1. Build and initialize engine
engine = FSDPEngine(config)
engine.init_model()
engine.set_loss_fn(loss_fn)

ctx = None  # optional context threaded through the pre/postprocess hooks

# 2. Training loop
for epoch in range(config.num_epochs):
    for batch in train_loader:
        # a) zero gradients
        engine.optimizer_zero_grad()
        # b) forward + backward
        with engine.train_mode():
            preds, loss, ctx = engine.forward_backward_step(
                batch, ctx, forward_only=False,
                preprocess_fn=preprocess_fn, postprocess_fn=postprocess_fn
            )
        # c) update and schedule
        grad_norm = engine.optimizer_step()
        current_lr = engine.lr_scheduler_step()

# 3. Evaluation
with engine.eval_mode():
    for micro_batch in data:
        preds, ctx = engine.forward_backward_step(
            micro_batch, ctx, forward_only=True,
            preprocess_fn=preprocess_fn, postprocess_fn=postprocess_fn
        )
```

### Detailed BaseEngine Interface

We now introduce an abstract base class, `BaseEngine`, which defines our unified training-engine interface.

**Key enhancements over the original RFC:**

- **`train_mode()` / `eval_mode()`**: context managers to control parameter and activation load/offload at the start and end of each loop.
- **`shard_data()` / `unshard_data()`**: APIs for partitioning and gathering data across devices or workers.
- **`preprocess_fn` / `postprocess_fn` in `forward_backward_step()`**: hooks to apply custom transformations before and after each micro-batch pass.

Below are the detailed signatures for each core method.

```python
class BaseEngine(object):
    """
    Abstract base class defining the interface for model training engines.

    Engine implementations must subclass BaseEngine and provide concrete
    behavior for all methods.
    """

    def __init__(self, config):
        """
        Initialize the BaseEngine.

        Args:
            config: Configuration object containing parameters for engine setup.
        """
        raise NotImplementedError

    def init_model(self):
        """
        Instantiate or load the model, optimizer, and learning rate scheduler.

        Should prepare all components necessary for training or evaluation.
        """
        raise NotImplementedError

    def train_mode(self):
        """
        Context manager entry for switching the engine and model into training mode.

        Usage:
            with engine.train_mode():
                # runs in training mode
        """
        raise NotImplementedError

    def eval_mode(self):
        """
        Context manager entry for switching the engine and model into evaluation mode.

        Usage:
            with engine.eval_mode():
                # runs in evaluation mode
        """
        raise NotImplementedError

    def forward_backward_step(self, batch, ctx=None, forward_only=False,
                              preprocess_fn=None, postprocess_fn=None):
        """
        Execute a forward pass (and optional backward pass) over a batch of data.

        Args:
            batch: Raw batch data (e.g., tensors or mappings) to process.
            ctx: Optional context dict passed to preprocess/postprocess functions.
            forward_only: If True, skip gradient computation and backward pass.
            preprocess_fn: Function(batch, ctx) -> (inputs, ctx), applied before model call.
            postprocess_fn: Function(outputs, ctx) -> (predictions, ctx), applied after model call.

        Returns:
            If forward_only: (predictions, ctx)
            Else: (predictions, loss, ctx)
        """
        raise NotImplementedError

    def optimizer_zero_grad(self):
        """
        Zero out gradients of all parameters before starting a new backward pass.
        """
        raise NotImplementedError

    def optimizer_step(self):
        """
        Perform an optimization step to update model parameters based on accumulated gradients.

        Returns:
            grad_norm (float): The norm of the gradients before clipping or update.
        """
        raise NotImplementedError

    def lr_scheduler_step(self):
        """
        Advance the learning rate scheduler by one step.

        Returns:
            current_lr (float or list[float]): Updated learning rate(s).
        """
        raise NotImplementedError

    def shard_data(self, data):
        """
        Shard or partition data for distributed training or parallel execution.

        Args:
            data: Data structure to be sharded across devices/workers.

        Returns:
            Sharded data in the same format as input.
        """
        raise NotImplementedError

    def unshard_data(self, data):
        """
        Reconstruct or gather sharded data back to a unified format.

        Args:
            data: Sharded data structure to reconstruct.

        Returns:
            Unsharded, combined data.
        """
        raise NotImplementedError

    def set_loss_fn(self, loss_fn):
        """
        Set the loss function to be used during training.

        Args:
            loss_fn: Callable(data, predictions, ctx) -> (loss_tensor, new_ctx)
        """
        raise NotImplementedError

    def to(self, device: str, model: bool = True, optimizer: bool = True):
        """
        Move model parameters, optimizer states, or both to the specified device.

        Args:
            device: Target device identifier (e.g., "cuda" or "cpu").
            model: If True, move the model.
            optimizer: If True, move the optimizer states.
        """
        raise NotImplementedError

    def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None):
        """
        Save model, optimizer, and scheduler states to a checkpoint.

        Args:
            local_path: Local filesystem path to save checkpoint.
            hdfs_path: Optional HDFS path to copy checkpoint.
            global_step: Integer training step number for naming.
            max_ckpt_to_keep: Maximum number of recent checkpoints to retain.
        """
        raise NotImplementedError

    def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=True):
        """
        Load model, optimizer, and scheduler states from a checkpoint.

        Args:
            local_path: Local filesystem path of the checkpoint.
            hdfs_path: Optional HDFS path where checkpoint is stored.
            del_local_after_load: Whether to delete local copy after loading.
        """
        raise NotImplementedError
```

### FSDPEngine Implementation

A concrete `FSDPEngine` implements all methods using PyTorch FullyShardedDataParallel, supporting all the features that the FSDP DPCritic worker supports:

- Multi-GPU model sharding
- Activation and optimizer offload
- LoRA and sequence parallelism
- Dynamic batch size and padding removal

### CriticWorker Implementation Based on the FSDPEngine

- Unchanged public API
- Each role calls only BaseEngine methods (init_model, train_mode/eval_mode, forward_backward_step, etc.)
- No modifications needed in existing algorithms (e.g., PPOTraining)
- New roles can be plugged in identically to legacy code

## Development Plan

We'll roll this out in three gated phases, controlled by a feature flag (`use_legacy_worker_impl`).

### Phase 1: Engine Development

> Flag: `use_legacy_worker_impl = True` (default)
> New interface under active development

- Refactor Critic, Actor, Rollout, Ref to use only BaseEngine APIs
- Design a hierarchical, immutable config system for engines/backends
- Ensure PPO training curves and final accuracy match the legacy implementation

### Phase 2: Migration

> Flag: `use_legacy_worker_impl = False` (default); the legacy path logs a deprecation warning
> All new code targets the new interface; 2–3 months of integration/stress testing

- Enforce the new interface for all feature work
- Gather benchmarks, bug reports, and performance data

### Phase 3: Cleanup

> After Phase 2 validation:

- Remove legacy worker code and flags
- Finalize documentation, update changelogs, close deprecation notices

Please review this refactor and share any feedback or concerns! Contributions are welcome.
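To make the interface above concrete, here is a minimal sketch of a loss function and pre/post-processing hooks matching the signatures documented in `set_loss_fn` and `forward_backward_step`; the value-loss formula and helper names are illustrative, not the implementation shipped in this PR:

```python
import torch


def preprocess_fn(batch, ctx):
    """Select the model inputs from the raw batch; stash anything needed later in ctx."""
    ctx = ctx or {}
    inputs = {"input_ids": batch["input_ids"], "attention_mask": batch["attention_mask"]}
    ctx["response_mask"] = batch["attention_mask"]  # illustrative; a real critic masks response tokens only
    return inputs, ctx


def postprocess_fn(outputs, ctx):
    """Turn raw model outputs into predictions (here: per-token value estimates)."""
    values = outputs.squeeze(-1)
    return values, ctx


def loss_fn(data, predictions, ctx):
    """Callable(data, predictions, ctx) -> (loss_tensor, new_ctx): a plain masked MSE value loss."""
    mask = ctx["response_mask"].float()
    vf_loss = ((predictions - data["returns"]) ** 2 * mask).sum() / mask.sum().clamp(min=1)
    ctx["metrics"] = {"critic/vf_loss": vf_loss.detach().item()}
    return vf_loss, ctx


# engine.set_loss_fn(loss_fn), then pass preprocess_fn/postprocess_fn to forward_backward_step
```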
A follow-up commit referenced this PR: …est, fix math_dataset path error (#2647)

### What does this PR do?

PR #1977 is a great piece of work. I tried using the new engine, found some minor problems, and added a CI test for `FSDPEngine`.

- Use the newest name `gather_outputs_and_unpad` for the gather function.
- Removed invalid calculations originally used for gradient accumulation (gradient accumulation has been moved to `loss_fn` in the new engine).
- Fixed misuses of two variables.

Signed-off-by: ShareLer <ShareLe@163.com>
Co-authored-by: eric-haibin-lin <linhaibin.eric@gmail.com>
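Regarding the point that gradient accumulation now lives in `loss_fn`: a rough sketch of what that scaling can look like, illustrative only and not the exact code in the engine:

```python
def make_accumulating_loss_fn(base_loss_fn, grad_accum_steps: int):
    """Wrap a loss_fn so each micro-batch contributes loss / grad_accum_steps,
    keeping the accumulated gradient equivalent to one large batch."""
    def loss_fn(data, predictions, ctx):
        loss, ctx = base_loss_fn(data, predictions, ctx)
        return loss / grad_accum_steps, ctx
    return loss_fn
```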