Description
Checklist
- 1. If the issue you raised is not a feature but a question, please raise a discussion at https://github.com/sgl-project/sglang/discussions/new/choose Otherwise, it will be closed.
- 2. Please use English, otherwise it will be closed.
Motivation
TL;DR: This issue proposes several features that would help integrate SGLang into veRL and may also benefit other post-training frameworks.
Provide an inference script that is launched by torchrun (SPMD support)
Currently, the offline inference script is launched by `sgl.Engine`, which internally spawns multiple `Scheduler` processes. With `torchrun`, each `Scheduler` is launched by `torchrun` directly, and the `tp_rank` can be obtained from the environment variables.
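For illustration, here is a minimal sketch of how a torchrun-launched worker could derive its ranks from the environment; the `tp_size` value and the rank-to-`tp_rank` mapping are assumptions for the example, not SGLang's actual logic:

```python
import os
import torch.distributed as dist

# torchrun sets these environment variables for every worker
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
local_rank = int(os.environ["LOCAL_RANK"])

dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)

# assumed mapping: contiguous ranks form a TP group of size tp_size
tp_size = 2
tp_rank = rank % tp_size   # rank within the TP group
dp_rank = None             # DP is managed by veRL's WorkerGroup, not by the engine
```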
In veRL, the data-parallel dimension is managed by our `WorkerGroup`, so the `dp_rank` of each `Scheduler` should be `None`. More specifically, suppose the current `WorkerGroup` has 8 GPUs and the rollout TP size is set to 2. All GPUs in this `WorkerGroup` build the distributed world, and the generation engine and the training engine each construct their own TP/PP groups. veRL's `data_protocol` partitions and dispatches the prompts to each TP/PP group without the generation engine being aware of the DP dimension.
Below is a general picture of a `torchrun` script that can simulate the HybridEngine behavior:
```python
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardedStateDictConfig, StateDictType

# build distributed world
local_rank, rank, world_size = initialize_global_process_group()

# build device mesh for the training engine
device_mesh = init_device_mesh('cuda', mesh_shape=(world_size,), mesh_dim_names=['fsdp'])
fsdp_model = FSDP(actor_model,
                  ...,  # other FSDP arguments elided
                  device_mesh=device_mesh)
FSDP.set_state_dict_type(fsdp_model,
                         state_dict_type=StateDictType.SHARDED_STATE_DICT,
                         state_dict_config=ShardedStateDictConfig())

# get the sharded model state dict
state_dict = fsdp_model.state_dict()

# [Optional] build device mesh for the inference engine
gen_device_mesh = init_device_mesh('cuda', mesh_shape=(2, 4), mesh_dim_names=['dp', 'tp'])

# build inference engine
inference_engine = SGLEngine(model_hf_config=actor_model_config,
                             tensor_parallel_size=tensor_model_parallel_size,
                             pipeline_parallel_size=pipeline_parallel_size,  # if any
                             enforce_eager=False,  # use CUDA graph even with offloaded KV cache and weights
                             dtype='bfloat16',
                             load_format='dummy_dtensor',  # initialize dummy weights
                             gpu_memory_utilization=0.1,
                             trust_remote_code=True)

# [Optional] update parallel state in SGLang for 3D-HybridEngine
inference_engine.update_parallel_state(TP=gen_device_mesh["tp"])

# sync weights between actor and rollout; several formats are supported: DTensor and Megatron (sharded)
inference_engine.sync_model_weights(actor_weights=state_dict, load_format='dtensor')

# generate sequences; it would be better if the output were a list of Tensors, not list[list[str]]
outputs = inference_engine.generate(prompt_token_ids=idx_list, sampling_params=sampling_params, use_tqdm=False)

# offload the KV cache after generation
inference_engine.free_kvcache()  # inference_engine.init_kvcache()

# offload the model weights
inference_engine.offload_model_weights()  # inference_engine.load_model_weights(); we can simply re-init them
```
Expose an API that can load weights in TP/PP format
See `inference_engine.sync_model_weights(actor_weights=state_dict, load_format='dtensor')` in the above code.
We may need two different load formats with different weight loaders:
- dtensor: The SGLang model weights are sharded; our state_dict is sharded differently, so we gather the weights layer by layer and feed them into the SGLang weight loader for synchronization (a minimal sketch is given after this list).
- megatron sharded: The SGLang model weights are sharded, and the veRL hybrid engine prepares a state_dict that is identical to SGLang's sharded weights. Therefore, the SGLang model can directly copy the weights in place without any further resharding.
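For illustration, here is a minimal sketch of the dtensor path, assuming a hypothetical per-tensor entry point on the SGLang side (`load_weights` below is an assumed hook, not a confirmed SGLang API):

```python
from torch.distributed._tensor import DTensor

def sync_dtensor_weights(actor_state_dict, sglang_model):
    """Gather each training-side tensor layer by layer and hand the full tensor to the engine."""
    for name, param in actor_state_dict.items():
        # materialize the full (unsharded) tensor from the training-side shard
        full_tensor = param.full_tensor() if isinstance(param, DTensor) else param
        # hypothetical hook: let SGLang's weight loader re-shard it for the rollout TP layout
        sglang_model.load_weights([(name, full_tensor)])
```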
Expose an API that can free/re-init kv cache, and offload/load model weights
`inference_engine.free_kvcache()` and `inference_engine.init_kvcache()`; `inference_engine.offload_model_weights()` and `inference_engine.load_model_weights()`.
It would be better to keep CUDA Graph support even when the KV cache and model weights are offloaded. Reference: #2542
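To make the intended lifecycle concrete, here is a hedged sketch of how these proposed (not yet existing) calls could wrap one rollout/training iteration; `dataloader`, `train_step`, and the fields on `batch` are placeholders:

```python
# sketch of one iteration that alternates rollout and training on the same GPUs
for batch in dataloader:
    # --- rollout phase: bring the inference engine back up ---
    inference_engine.load_model_weights()   # or simply re-init dummy weights
    inference_engine.init_kvcache()
    inference_engine.sync_model_weights(actor_weights=fsdp_model.state_dict(), load_format='dtensor')
    outputs = inference_engine.generate(prompt_token_ids=batch.prompt_token_ids,
                                        sampling_params=sampling_params)

    # --- training phase: release rollout memory so the training engine can use the GPU ---
    inference_engine.free_kvcache()
    inference_engine.offload_model_weights()
    train_step(fsdp_model, batch, outputs)
```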
Disable detokenization during generation.
In most RL training scenarios we only need the token_ids, and we can perform a batched detokenization when we actually need the text. We don't care about the ITL (inter-token latency) metric.
After detokenization is disabled, we can check whether there are further opportunities to improve throughput.
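For example, if the engine returns only token ids, a single batched detokenization with a Hugging Face tokenizer covers the rare cases where text is needed; note that the `skip_detokenize` flag below is a hypothetical option, not an existing engine argument:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_path)  # same model as the rollout engine

# hypothetical flag: ask the engine to return token ids only
output_token_ids = inference_engine.generate(prompt_token_ids=idx_list,
                                             sampling_params=sampling_params,
                                             skip_detokenize=True)

# detokenize in one batched call only when the text is actually needed (e.g. for logging)
texts = tokenizer.batch_decode(output_token_ids, skip_special_tokens=True)
```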
3D-HybridEngine parallel state construction (TP/PP group generation logic should be different from Megatron-LM when using 3D-HybridEngine)
With our 3D-HybridEngine design (in the paper and the code), the TP/PP grouping strategy in SGLang depends on the TP/PP sizes used in the training framework.
However, we consider that SGLang does not necessarily need to be aware of the training framework's TP/PP sizes.
So, we can build the TP/PP groups for SGLang before SGLang initialization and then pass these TP/PP groups to the SGLEngine. See the [Optional] lines in the above code.
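Here is a minimal sketch of building rollout TP groups on the veRL side before SGLang starts; `update_parallel_state` is the API proposed above, and the contiguous-rank layout is only one possible choice:

```python
import torch.distributed as dist

# example layout: 8 GPUs with rollout TP size 2 -> 4 TP groups of contiguous ranks
rollout_tp_size = 2
world_size = dist.get_world_size()
rank = dist.get_rank()

tp_group = None
for start in range(0, world_size, rollout_tp_size):
    ranks = list(range(start, start + rollout_tp_size))
    group = dist.new_group(ranks=ranks)  # every rank must call new_group for every group
    if rank in ranks:
        tp_group = group

# hand the pre-built group to the engine (the proposed [Optional] step above)
inference_engine.update_parallel_state(TP=tp_group)
```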
Output post-process to torch.Tensor (token_ids).
This is a small feature; if it is not supported, we can implement the post-processing in veRL instead (a minimal sketch is shown below). No worries.
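If the engine keeps returning list[list[int]], the post-processing in veRL could look like this; the padding value and the batch-first layout are arbitrary choices for illustration:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

def token_id_lists_to_tensor(output_token_ids, pad_token_id=0):
    """Pad variable-length responses into a single (batch, max_len) LongTensor."""
    sequences = [torch.tensor(ids, dtype=torch.long) for ids in output_token_ids]
    return pad_sequence(sequences, batch_first=True, padding_value=pad_token_id)
```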
Related resources
No response