
Conversation

Collaborator

@Irvingwangjr Irvingwangjr commented Jun 25, 2025

Background

Currently, rollout suffers from long-tail generations: within one batch generation, some samples can be much longer than the rest and become stragglers.

Main Idea

To solve this, we implement a StreamScheduler. This strategy keeps fetching data from the data iterator and sending it to the serving engine until the stop criteria are met. It works because we essentially fill each batch with the shortest generations and postpone the long-tail samples (even though those are also sent to the engine, they are eventually 'dropped', since other prefetched samples reach the output buffer faster).
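Roughly, the scheduling loop looks like the sketch below. This is only an illustration of the idea, not the actual StreamScheduler code from this PR; stream_rollout, generate, and max_inflight are placeholder names.

import asyncio
from typing import Any, AsyncIterator, Awaitable, Callable

async def stream_rollout(
    prompts: AsyncIterator[Any],
    generate: Callable[[Any], Awaitable[Any]],   # placeholder for the engine call
    batch_size: int,
    max_inflight: int,
) -> list[Any]:
    # Keep the engine busy and return the first `batch_size` finished generations.
    inflight: set[asyncio.Task] = set()
    finished: list[Any] = []

    async def top_up():
        # Submit more prompts without exceeding the in-flight cap (host-memory control).
        while len(inflight) < max_inflight:
            try:
                prompt = await anext(prompts)
            except StopAsyncIteration:
                break
            inflight.add(asyncio.create_task(generate(prompt)))

    await top_up()
    while inflight and len(finished) < batch_size:
        done, pending = await asyncio.wait(inflight, return_when=asyncio.FIRST_COMPLETED)
        inflight = set(pending)
        finished.extend(task.result() for task in done)
        await top_up()

    # Stop criteria met: abort the long-tail requests still in flight.
    for task in inflight:
        task.cancel()
    return finished[:batch_size]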

A couple of things we need to be careful with:

  • Host memory: the data fetcher may load up to expect_batch_size*n prompts into memory, since it tries to maximize engine utilization by feeding in as much data as possible. We need to cap the maximum number of in-flight requests (data in the global queue, tool-calling state, serving state) to avoid holding too many requests in memory (see the fetcher sketch after this list).
  • Abort pattern: when we hit the stop criteria, we should cancel the in-flight requests. The open question is how to handle the results of partial generations: drop them, save the KV cache, save the partial result, or apply a staleness factor. We also need to handle the requeue pattern, since for GRPO we must requeue all n samples of a prompt together.
  • Re-generating batch/gen_batch for training.
  • Integration of the PyTorch DataLoader with asyncio.
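On the last two points, one way to bridge a blocking PyTorch DataLoader with the asyncio scheduler while bounding host memory is a bounded queue, roughly as sketched below (an illustration only, not the PR's implementation; the queue size and function names are assumptions).

import asyncio
from torch.utils.data import DataLoader

_SENTINEL = object()

async def bounded_fetcher(dataloader: DataLoader, queue: asyncio.Queue) -> None:
    # Feed prompts from the blocking DataLoader into a bounded asyncio.Queue.
    # The queue's maxsize (e.g. expect_batch_size * n) caps how many prefetched
    # prompts sit in host memory; `await queue.put(...)` applies backpressure
    # when the scheduler falls behind.
    loop = asyncio.get_running_loop()
    it = iter(dataloader)
    while True:
        # Run next() in a worker thread so the DataLoader never blocks the event loop.
        item = await loop.run_in_executor(None, lambda: next(it, _SENTINEL))
        if item is _SENTINEL:
            break
        await queue.put(item)

# usage sketch:
#   queue = asyncio.Queue(maxsize=expect_batch_size * n)
#   asyncio.create_task(bounded_fetcher(dataloader, queue))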


Benchmark

Rollout Bench

  • dataset: Eurus-2-RL-Data, code samples only, first 5000 samples.
  • max_output_length = 16k
  • model: qwen2-7b
  • 4*H20
  • batch_size = 1024.
  • There are a total of 5 mini-batches. The total time cost is: stream = 619.75s vs native = 1377.73s, showing a 2.22× speedup.
  • Excluding the last batch and focusing on the earlier, more stable mini-batches, the time decreased from around 270s to 70s.

E2E Bench

  • dataset: GSM8k

  • max_output_length = 16k

  • model: qwen2-7b

  • 8*H20, dp=8

  • Rollout throughput: baseline vs stream = 0.11 vs 0.05. Overall, there’s approximately a 2× improvement, and notably, stream mode shows almost no long-tail samples.

  • The trend of max response length also verifies that the baseline is slower because of the longest samples.


  • There is no significant difference in metrics on the test set.


  • Convergence on the training set is smooth with no issues.

@Irvingwangjr Irvingwangjr changed the title feat: support streaming rollout for on-policy algorithm [WIP]feat: support streaming rollout for on-policy algorithm Jun 25, 2025
Collaborator

@chenhaiq chenhaiq left a comment


As we previously discussed with haibin, the streaming partial-rollout feature needs to be placed as a recipe to keep the main trainer as simple as possible.

@Irvingwangjr Irvingwangjr force-pushed the feat/async-tools branch 2 times, most recently from ee56567 to dd070e3 on June 27, 2025 at 09:54
@eric-haibin-lin eric-haibin-lin self-assigned this Jul 4, 2025
@Irvingwangjr Irvingwangjr force-pushed the feat/async-tools branch 2 times, most recently from b99485b to 152c8e2 on July 10, 2025 at 08:29
@PrinsYin
Contributor

Impressive work! Curious: does this optimization work best with high max_response_length settings and high variance in actual response lengths? Would the gains be smaller for other tasks?

@Irvingwangjr
Collaborator Author

Impressive work! Curious: does this optimization work best with high max_response_length settings and high variance in actual response lengths? Would the gains be smaller for other tasks?

Yes, it might show less effect when max_response_length is rather small.

Our goal is to solve the long-tail problem. Assume each batch contains one long-tail sample whose inference time is t_max, the dataset contains b batches, and the normal sample time is t_mean. Then:

  1. baseline time: t_max * b
  2. stream-mode time: t_mean * (b - 1) + t_max
  3. So for b >> 1, the average rollout time per batch in stream mode approaches t_mean, while the baseline stays at t_max.

When max_response_length is rather small, t_max becomes smaller; and with low variance, some batches simply don't contain any long-tail sample, so the gains shrink.
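As a rough sanity check, plugging the rollout-bench numbers above (5 mini-batches, roughly 270s per mini-batch for the baseline vs 70s for stream mode) into these formulas lands in the same ballpark as the measured totals:

b, t_max, t_mean = 5, 270.0, 70.0
baseline = t_max * b                 # = 1350s (measured: 1377.73s)
stream = t_mean * (b - 1) + t_max    # = 550s  (measured: 619.75s)
print(baseline / stream)             # ~2.45x, close to the measured 2.22x speedup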

@Irvingwangjr Irvingwangjr changed the title [WIP]feat: support streaming rollout for on-policy algorithm [Rollout]feat: support streaming rollout for on-policy algorithm Jul 14, 2025
@Irvingwangjr Irvingwangjr force-pushed the feat/async-tools branch 2 times, most recently from 4703452 to 9b12dea on July 14, 2025 at 13:23
@Irvingwangjr Irvingwangjr changed the title [Rollout]feat: support streaming rollout for on-policy algorithm feat: support streaming rollout for on-policy algorithm Jul 14, 2025
@Irvingwangjr Irvingwangjr changed the title feat: support streaming rollout for on-policy algorithm [rollout] feat: support streaming rollout for on-policy algorithm Jul 14, 2025
@Irvingwangjr Irvingwangjr force-pushed the feat/async-tools branch 2 times, most recently from f27286c to d580644 on July 15, 2025 at 03:34
@eric-haibin-lin
Collaborator

/gemini review

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a streaming rollout mechanism to improve performance for on-policy algorithms, which is a valuable addition. The implementation is comprehensive, including a new StreamScheduler, asynchronous workers, and cancellation logic.

I've found a couple of issues: one critical bug related to an incorrect configuration access that will cause a runtime error, and one high-severity issue in a utility function that could lead to incorrect behavior if used in a more general context. Please address these points to ensure the stability and correctness of this new feature.

model_name: str
messages: list[dict[str, str]]
sampling_params: dict[str, Any]
agent_name: np.ndarray
Collaborator

Why is the name an ndarray?

@@ -327,6 +340,28 @@ async def generate(self, prompt_ids: list[int], sampling_params: dict[str, Any],

return final_res.outputs[0].token_ids

Collaborator

@wuxibin89 please help review this part

@@ -0,0 +1,65 @@
set -x

@Irvingwangjr Irvingwangjr force-pushed the feat/async-tools branch 2 times, most recently from 574202f to 9482b07 on July 21, 2025 at 05:33
@chenhaiq chenhaiq requested review from eric-haibin-lin and removed request for SwordFaith, zhaochenyang20 and tongyx361 August 12, 2025 09:49
@@ -338,6 +352,10 @@ async def generate(
# TODO: vllm image_data supporting like sglang async server
if image_data is not None:
raise NotImplementedError("image_data is not supported for vLLM rollout")
with ExitStack() as stack:
self.req_result[request_id] = None
stack.callback(lambda: self.req_result.pop(request_id, None))
Collaborator

This ExitStack seems to be redundant.

return generation_handle.result()
return None

async def cancel(self, request_id: str):
Collaborator

Why not use AsyncLLM.abort?

async def cancel_req(self, request_id: str):
logger.debug(f"cancel request {request_id} from cancel_req")
if request_id in self.active_req:
self._engine.tokenizer_manager.abort_request(request_id)
Collaborator

I think it's safe to abort a request that is already done, hence we can eliminate self.active_req.

request_sampling_params.update(sampling_params)
output = await self._handle_engine_generate(prompt_ids, request_sampling_params, image_data=image_data)
return output["output_ids"]
with ExitStack() as stack:
Collaborator

I don't understand why we need ExitStack?

@@ -1772,6 +1772,11 @@ async def generate(
ret = await self.rollout.generate(prompt_ids, sampling_params, request_id, image_data=image_data)
return ret

@register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD, blocking=False)
async def cancel_req(self, request_id: str) -> list[int]:
Collaborator

Better to name it abort, to align with the inference engine.



# copy from vllm
def with_cancellation(handler_func):
Collaborator

Where's this function used?

@@ -297,6 +303,14 @@ async def init_engine(self):
tool_parser=config.multi_turn.format, # hermes, llama3_json, ...
)

async def _force_log(stat_log_interval=10):
Collaborator

Why do we need this?

Comment on lines +59 to +63
if partial_rollout:
# response_id = [prompt_ids-prompt_token_ids,response_ids]
response_ids = prompt_ids[len(token_ids) :] + response_ids
# prompt as original
prompt_ids = prompt_ids[: len(token_ids)]

Noticed a bug here. Let's say:

  • Original prompt: "Hello, how are you?" → tokens [1, 2, 3, 4]

  • Saved partial response: "I'm doing" → tokens [5, 6, 7]

  • New generation: "well, thanks!" → tokens [8, 9, 10]

We have

original_prompt_ids = [1, 2, 3, 4]
token_ids = [5, 6, 7]
prompt_ids = [1, 2, 3, 4] + [5, 6, 7] = [1, 2, 3, 4, 5, 6, 7]
response_ids = [8, 9, 10]

Current logic gives us:

response_ids = prompt_ids[len(token_ids):] + response_ids
# = [1,2,3,4,5,6,7][3:] + [8,9,10]  
# = [4,5,6,7] + [8,9,10]
# = [4,5,6,7,8,9,10]

prompt_ids = prompt_ids[:len(token_ids)]
# = [1,2,3,4,5,6,7][:3]
# = [1,2,3]  

We can see that response_ids will include the original prompt token [4] and prompt_ids is truncated.

Maybe we can do the following instead?

original_prompt_len = len(prompt_ids) - len(token_ids)
response_ids = prompt_ids[original_prompt_len :] + response_ids
prompt_ids = prompt_ids[: original_prompt_len]              
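A quick check of the proposed fix against the example tokens above (illustrative snippet only, reusing the values from this comment):

original_prompt_ids = [1, 2, 3, 4]
token_ids = [5, 6, 7]                           # saved partial response
prompt_ids = original_prompt_ids + token_ids    # [1, 2, 3, 4, 5, 6, 7]
response_ids = [8, 9, 10]                       # new generation

original_prompt_len = len(prompt_ids) - len(token_ids)   # 4
response_ids = prompt_ids[original_prompt_len:] + response_ids
prompt_ids = prompt_ids[:original_prompt_len]

assert prompt_ids == [1, 2, 3, 4]
assert response_ids == [5, 6, 7, 8, 9, 10]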



# copy from agent-loop
def get_agent_loop_class(agent_name: str) -> type[AgentLoopBase]:

is this method used?

@@ -35,20 +35,33 @@ def __init__(self, *args, **kwargs):

async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutput:
messages = list(kwargs["raw_prompt"])
token_ids = kwargs.get("token_ids", None)

Do we only support partial_rollout for SingleTurnAgentLoop? Can we up-level this partial_rollout support so all agent loop classes get this for free?

Specifically for ToolAgentLoop, we may need additional logic to handle tool-use state, cancellation, and recovery. We can follow up with changes if this is out of scope for this PR.

reward_tensor = self.rm_wg.compute_rm_score(batch)
batch = batch.union(reward_tensor)

if self.config.reward_model.launch_reward_fn_async:

Idea on this: we could start reward computation at the sample level right after each generation finishes, without waiting for a full batch to be formed, although running it async with old_log_prob can hide most of the reward-computation wait time in most cases.


Ahh already a PR on this: #3055

@shuowoshishui

According to this:
Fine-Grained Request-Level Parallelism:
In current rollout implementations, batch-level inputs are typically split into request-level units for inference, and later aggregated back into a training batch. This request-level granularity opens up opportunities for smarter scheduling.

Since we have implemented fine-grained inference based on such principles, when will we implement a scheme where training starts asynchronously as soon as enough samples for a micro_batch are ready, advancing the training start time?
