- Prerequisite: Make sure the LLM inference framework can be launched in the SPMD style, i.e., every process runs the same script. For example, the LLM inference script can be launched by

  ```
  torchrun --standalone --nproc_per_node=8 offline_inference.py
  ```
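To make the SPMD requirement concrete, here is a minimal sketch of what such an `offline_inference.py` might look like. The prompt sharding logic and the engine are invented for illustration; the only real assumption is that `torchrun` sets `RANK` and `WORLD_SIZE` in each process's environment.

```python
# Hypothetical sketch of an SPMD-style offline_inference.py: every rank runs
# this same script and processes its own shard of the prompts. A real script
# would construct the inference engine per rank instead of printing.
import os

def shard_for_rank(prompts, rank, world_size):
    """Return the contiguous slice of prompts owned by this rank."""
    per_rank = (len(prompts) + world_size - 1) // world_size  # ceil division
    return prompts[rank * per_rank : (rank + 1) * per_rank]

def main():
    # torchrun sets RANK / WORLD_SIZE for every spawned process.
    rank = int(os.environ.get("RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    prompts = [f"prompt-{i}" for i in range(16)]
    local_prompts = shard_for_rank(prompts, rank, world_size)
    print(f"rank {rank}/{world_size} handles {len(local_prompts)} prompts")

if __name__ == "__main__":
    main()
```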
- A Rollout class: Build an `xxx_rollout.py` script similar to `vllm_rollout.py`. In this file, define an `xxxRollout` class that inherits from `BaseRollout`.
  - This class should have a `generate_sequence` API that accepts a batch of `input_ids`, `response_masks`, and `position_ids` from the `DataProto` as input. The `self.inference_engine` (e.g., `LLMEngine` in vLLM) is responsible for performing auto-regressive generation and outputting a batch of responses. These responses should then be concatenated with `input_ids`, and the `response_masks` and `position_ids` should be reconstructed accordingly.
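The Rollout contract above can be sketched as follows. This is an illustrative stand-in, not verl's implementation: the toy engine and the list-based batch are invented, and the real classes operate on torch tensors from `DataProto`.

```python
# Hedged sketch of the Rollout interface described above. BaseRollout's real
# definition lives in verl; the toy engine plays the role of
# self.inference_engine (e.g., vLLM's LLMEngine).
class BaseRollout:  # stand-in for verl's BaseRollout
    def generate_sequence(self, batch):
        raise NotImplementedError

class ToyEngine:
    def generate(self, input_ids):
        # Pretend every prompt yields the same fixed 3-token response.
        return [[101, 102, 103] for _ in input_ids]

class XxxRollout(BaseRollout):
    def __init__(self):
        self.inference_engine = ToyEngine()

    def generate_sequence(self, batch):
        input_ids = batch["input_ids"]
        responses = self.inference_engine.generate(input_ids)
        out = []
        for prompt, resp, mask, pos in zip(
            input_ids, responses, batch["response_masks"], batch["position_ids"]
        ):
            seq = prompt + resp                # concatenate prompt + response
            new_mask = mask + [1] * len(resp)  # generated tokens are real tokens
            # extend position_ids past the last prompt position
            new_pos = pos + list(range(pos[-1] + 1, pos[-1] + 1 + len(resp)))
            out.append({"input_ids": seq,
                        "response_masks": new_mask,
                        "position_ids": new_pos})
        return out
```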
- ShardingManager classes for weight synchronization with training frameworks: Create files named `fsdp_xxx.py` and `megatron_xxx.py`, similar to `fsdp_vllm.py` and `megatron_vllm.py`. These files should define `XXXShardingManager` classes (i.e., the HybridEngine) that handle weight sharding between the training and inference frameworks.
  - In `megatron_vllm.py`, we define an `AllGatherPPModel` class to collect weights across the pipeline-parallelism dimension. The parameters stored in the `memory_buffers` of `AllGatherPPModel` are used to synchronize the weights with the models in the vLLM rollout.
- Weight loading issues: It may be necessary to provide model-specific weight loaders for transferring weights between each pair of LLM inference and training backends. This is similar to the `dtensor_weight_loader.py` and `megatron_weight_loader.py` files for vLLM.
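A per-model weight loader mostly translates parameter names between the two backends' state dicts. The sketch below is hypothetical: the `module.` prefix rule is invented for illustration, and real loaders such as `dtensor_weight_loader.py` copy into live tensors (e.g., via `copy_`) rather than replacing dict entries.

```python
# Hypothetical per-model weight loader: map parameter names from the training
# backend's state dict onto the inference model's parameters.
def xxx_weight_loader(training_state_dict, inference_params):
    """Copy training weights into inference params, translating names."""
    for name, weight in training_state_dict.items():
        # Invented renaming rule: the training backend prefixes "module.".
        target = name.removeprefix("module.")
        if target not in inference_params:
            raise KeyError(f"no inference parameter for {name}")
        inference_params[target] = weight  # real loaders copy_ into tensors
    return inference_params
```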