
[RFC] Support Multi-Stage Awake for RL #7009

@hebiao064

Motivation

Reduce peak GPU memory usage by N GB to avoid OOM, where N is the memory size of the model weights.

Model       Before      After       Improvement
Qwen 32B    192 GB      131.8 GB    30%
Llama 8B    148.9 GB    134.5 GB    10%

  • Qwen 32B test (mem_frac = 0.9), memory trace: 5.8 → 36.6 → 98.1 → 37.6 → 131.8 GB
  • Llama 8B test (mem_frac = 0.9), memory trace: 7.1 → 14.7 → 30.1 → 15.7 → 134.5 GB

Background

In RL ecosystems that use a colocated design, such as verl, we need to frequently offload the training model and load the serving model and KV cache.

  • Currently SGLang uses torch_memory_saver to pause and resume GPU memory (see the usage sketch below).
  • torch_memory_saver is an open-source library that provides an easy-to-use API to hook cudaMalloc and cudaFree, so that virtual addresses stay consistent across pause and resume, which is critical for CUDA Graph to keep working.
  • CUDA Graph is critical for SGLang's performance in the decoding phase.
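
For reference, here is a minimal sketch of how the pause/resume API is used; the calls follow torch_memory_saver's public README, but treat the exact names as illustrative:

```python
import torch
from torch_memory_saver import torch_memory_saver

# Tensors allocated inside `region()` go through the hooked cudaMalloc,
# so their physical pages can later be released and re-mapped at the same
# virtual addresses (which keeps captured CUDA Graphs valid).
with torch_memory_saver.region():
    kv_cache = torch.zeros(1_000_000, 128, dtype=torch.float16, device="cuda")

torch_memory_saver.pause()   # release physical memory, keep virtual addresses
torch_memory_saver.resume()  # re-attach physical memory at the same addresses
```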

Here is the current behavior of verl + SGLang:

[Diagram: current behavior of verl + SGLang]

  1. During training, the training model and optimizer state live in GPU memory. Once training is done, we offload the optimizer state to CPU and keep the model weights on GPU, since they are needed for the weight update.
  2. During the weight update, we awake the SGLang engine, so the paused memory for model weights and KV cache comes back. We then update the serving model from the training model on the fly via the update_weights_in_tensor API.
  3. After the model is updated, we delete the training model from GPU memory (as sketched below).
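
To make the ordering explicit, here is a rough pseudo-code sketch of this flow; the helper functions and engine methods are placeholders, not exact verl/SGLang APIs:

```python
# Hypothetical sketch of the current single-stage flow (placeholder helpers).
def rollout_step(trainer, engine):
    offload_optimizer_to_cpu(trainer)        # 1. keep only the training weights on GPU

    engine.resume_memory_occupation()        # 2. weights + KV cache wake up together
    engine.update_weights_in_tensor(         #    copy training weights into the serving model
        named_tensors=list(trainer.model.named_parameters())
    )

    release_training_model(trainer)          # 3. only now is the training copy freed

    # Peak memory here ≈ training weights + serving weights + KV cache.
    run_rollout(engine)
```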

The design above has worked well so far; however, it wastes a large chunk of GPU memory during rollout, which causes issues we have already observed:

  • Small KV cache: we have to use a relatively low memory fraction ratio (e.g. 0.6), so the KV cache holds fewer tokens. With a small KV cache, we hit RuntimeError: Prefill out of memory. Try to lower your batch size. when prefilling a large number of requests.
  • Out of memory: with a memory fraction ratio of 0.8, running RL for a 32B model on 8 H100 GPUs OOMs during the weight update.

Challenge

  • torch_memory_saver currently only supports a singleton, so SGLang pauses and resumes the KV cache and weights together; they are treated as a single group of memory controlled by the singleton torch_memory_saver instance.

Proposal

[Diagram: proposed multi-stage awake flow]

  1. During training, we do the same as before.
  2. During Update Weight Stage 1, we awake only SGLang's model weights and then update the weights.
  3. During Update Weight Stage 2, we delete the training model weights from GPU memory.
  4. Finally, we awake SGLang's KV cache (see the sketch below).
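
A corresponding sketch of the proposed ordering (same placeholder helpers; the staged resume methods illustrate the proposal rather than an existing API). The key difference is that the KV cache is resumed only after the training weights are gone:

```python
# Hypothetical sketch of the proposed multi-stage awake (placeholder helpers).
def rollout_step_multi_stage(trainer, engine):
    offload_optimizer_to_cpu(trainer)        # 1. same as before

    engine.resume_weights_only()             # 2. Stage 1: wake up only the model weights
    engine.update_weights_in_tensor(
        named_tensors=list(trainer.model.named_parameters())
    )

    release_training_model(trainer)          # 3. Stage 2: free the training copy first...

    engine.resume_kv_cache_only()            # 4. ...then wake up the KV cache, so it never
    run_rollout(engine)                      #    coexists with two copies of the weights
```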


Benefit

With this feature, we can train larger models on the same GPUs, and we can make training/rollout more efficient since we can allocate a larger KV cache.

Option 1: Use two torch_memory_saver instances
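
A sketch of what Option 1 might look like, assuming torch_memory_saver allowed multiple independent instances (today it does not, which is the gap this option has to close); the class name and region API are taken from the repo, everything else is illustrative:

```python
import torch
from torch_memory_saver import TorchMemorySaver  # class name assumed from the repo

weight_saver = TorchMemorySaver()
kv_cache_saver = TorchMemorySaver()

# Each memory group gets its own saver so it can be paused/resumed independently.
with weight_saver.region():
    serving_weights = torch.empty(4096, 4096, dtype=torch.float16, device="cuda")
with kv_cache_saver.region():
    kv_cache = torch.empty(1_000_000, 128, dtype=torch.float16, device="cuda")

weight_saver.pause()
kv_cache_saver.pause()

weight_saver.resume()     # Update Weight Stage 1: only the weights come back
# ... update weights, then delete the training model ...
kv_cache_saver.resume()   # Update Weight Stage 2: the KV cache comes back
```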

Option 2: Keep using the singleton and provide tag-based pause/resume
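
And a sketch of Option 2, keeping the singleton but tagging each region so pause/resume can target one group at a time; the tag argument is the proposed extension, not an existing parameter:

```python
import torch
from torch_memory_saver import torch_memory_saver

# Proposed: tag each region so the singleton can pause/resume groups selectively.
with torch_memory_saver.region(tag="weights"):
    serving_weights = torch.empty(4096, 4096, dtype=torch.float16, device="cuda")
with torch_memory_saver.region(tag="kv_cache"):
    kv_cache = torch.empty(1_000_000, 128, dtype=torch.float16, device="cuda")

torch_memory_saver.pause(tag="weights")
torch_memory_saver.pause(tag="kv_cache")

torch_memory_saver.resume(tag="weights")    # Stage 1: weights only
# ... update weights, then delete the training model ...
torch_memory_saver.resume(tag="kv_cache")   # Stage 2: KV cache
```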

