[rollout] feat: support reorder rollout for tackling long-tail generation problem #2200
Conversation
As we previously discussed with Haibin, the streaming partial-rollout feature needs to be placed as a recipe to keep the main trainer as simple as possible.
Impressive work! Curious: does this optimization work best with high max_response_length settings and high variance in actual response lengths? Would the gains be smaller for other tasks?
Yes, it may show less benefit when max_response_length is rather small. Our goal is to solve the long-tail problem: assume each batch contains one long-tail sample whose inference time is t_max, the whole dataset contains num_b batches, and a normal sample takes t_mean. Then the baseline spends roughly num_b * t_max, while streaming keeps each batch gated by t_mean, so the potential speedup is on the order of t_max / t_mean.
When max_response_length is small, t_max itself shrinks, and with low variance some batches may not contain any long-tail sample at all, so the gains are correspondingly smaller.
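To make the arithmetic concrete, a back-of-envelope sketch (the numbers are made up for illustration; only t_max, t_mean, and num_b come from the comment above):

t_max = 300.0   # seconds for the long-tail sample in a batch (illustrative)
t_mean = 60.0   # seconds for a typical sample (illustrative)
num_b = 100     # number of batches in the dataset (illustrative)

baseline_time = num_b * t_max          # every batch waits for its slowest sample
stream_time = num_b * t_mean + t_max   # long-tail work is postponed/overlapped; only the final tail is exposed

print(f"speedup ~= {baseline_time / stream_time:.2f}x")  # tends toward t_max / t_mean as num_b grows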
/gemini review
Code Review
This pull request introduces a streaming rollout mechanism to improve performance for on-policy algorithms, which is a valuable addition. The implementation is comprehensive, including a new StreamScheduler, asynchronous workers, and cancellation logic.
I've found a couple of issues: one critical bug related to an incorrect configuration access that will cause a runtime error, and one high-severity issue in a utility function that could lead to incorrect behavior if used in a more general context. Please address these points to ensure the stability and correctness of this new feature.
model_name: str
messages: list[dict[str, str]]
sampling_params: dict[str, Any]
agent_name: np.ndarray
Why is the name an ndarray?
@@ -327,6 +340,28 @@ async def generate(self, prompt_ids: list[int], sampling_params: dict[str, Any],

        return final_res.outputs[0].token_ids
@wuxibin89 please help review this part
@@ -0,0 +1,65 @@
set -x
Please add an end-to-end test to CI, similar to https://github.com/volcengine/verl/blob/main/.github/workflows/e2e_dapo.yml
@@ -338,6 +352,10 @@ async def generate(
        # TODO: support image_data for vLLM like the sglang async server
        if image_data is not None:
            raise NotImplementedError("image_data is not supported for vLLM rollout")
        with ExitStack() as stack:
            self.req_result[request_id] = None
            stack.callback(lambda: self.req_result.pop(request_id, None))
This ExitStack seems to be redundant.
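For illustration, the same bookkeeping without ExitStack is just a try/finally (a sketch; it assumes self.req_result is a dict keyed by request id and that the generation call sits inside the with block in the PR):

self.req_result[request_id] = None
try:
    ...  # submit the request to the engine and await the final result
finally:
    # always drop the slot, whether generation succeeded, failed, or was cancelled
    self.req_result.pop(request_id, None)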
            return generation_handle.result()
        return None

    async def cancel(self, request_id: str):
Why not use AsyncLLM.abort?
    async def cancel_req(self, request_id: str):
        logger.debug(f"cancel request {request_id} from cancel_req")
        if request_id in self.active_req:
            self._engine.tokenizer_manager.abort_request(request_id)
I think it's safe to abort a request that is already done, hence we can eliminate self.active_req.
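A sketch of the simplification being suggested, assuming abort_request is harmless for requests that have already finished (which is what the comment asserts):

    async def cancel_req(self, request_id: str):
        # abort unconditionally; no self.active_req bookkeeping needed if
        # aborting an already-finished request is a no-op
        logger.debug(f"cancel request {request_id} from cancel_req")
        self._engine.tokenizer_manager.abort_request(request_id)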
        request_sampling_params.update(sampling_params)
        output = await self._handle_engine_generate(prompt_ids, request_sampling_params, image_data=image_data)
        return output["output_ids"]
        with ExitStack() as stack:
I don't understand why we need ExitStack here?
@@ -1772,6 +1772,11 @@ async def generate(
        ret = await self.rollout.generate(prompt_ids, sampling_params, request_id, image_data=image_data)
        return ret

    @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD, blocking=False)
    async def cancel_req(self, request_id: str) -> list[int]:
Better to name it abort to align with the inference engine.
# copy from vllm
def with_cancellation(handler_func):
Where's this function used?
@@ -297,6 +303,14 @@ async def init_engine(self):
            tool_parser=config.multi_turn.format,  # hermes, llama3_json, ...
        )

        async def _force_log(stat_log_interval=10):
Why do we need this?
    if partial_rollout:
        # response_id = [prompt_ids - prompt_token_ids, response_ids]
        response_ids = prompt_ids[len(token_ids):] + response_ids
        # prompt as original
        prompt_ids = prompt_ids[: len(token_ids)]
Noticed a bug here. Let's say:
- Original prompt: "Hello, how are you?" → tokens [1, 2, 3, 4]
- Saved partial response: "I'm doing" → tokens [5, 6, 7]
- New generation: "well, thanks!" → tokens [8, 9, 10]
We have
original_prompt_ids = [1, 2, 3, 4]
token_ids = [5, 6, 7]
prompt_ids = [1, 2, 3, 4] + [5, 6, 7] = [1, 2, 3, 4, 5, 6, 7]
response_ids = [8, 9, 10]
Current logic gives us:
response_ids = prompt_ids[len(token_ids):] + response_ids
# = [1,2,3,4,5,6,7][3:] + [8,9,10]
# = [4,5,6,7] + [8,9,10]
# = [4,5,6,7,8,9,10]
prompt_ids = prompt_ids[:len(token_ids)]
# = [1,2,3,4,5,6,7][:3]
# = [1,2,3]
We can see that response_ids will include the original prompt token [4] and that prompt_ids is truncated.
Maybe we can do the following instead?
original_prompt_len = len(prompt_ids) - len(token_ids)
response_ids = prompt_ids[original_prompt_len :] + response_ids
prompt_ids = prompt_ids[: original_prompt_len]
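A quick standalone check of the proposed fix against the token ids from the example above:

token_ids = [5, 6, 7]                  # saved partial response
prompt_ids = [1, 2, 3, 4] + token_ids  # prompt sent back to the engine for resumption
response_ids = [8, 9, 10]              # newly generated tokens

original_prompt_len = len(prompt_ids) - len(token_ids)
response_ids = prompt_ids[original_prompt_len:] + response_ids
prompt_ids = prompt_ids[:original_prompt_len]

assert prompt_ids == [1, 2, 3, 4]
assert response_ids == [5, 6, 7, 8, 9, 10]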
# copy from agent-loop
def get_agent_loop_class(agent_name: str) -> type[AgentLoopBase]:
Is this method used?
@@ -35,20 +35,33 @@ def __init__(self, *args, **kwargs):

    async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutput:
        messages = list(kwargs["raw_prompt"])
        token_ids = kwargs.get("token_ids", None)
Do we only support partial_rollout for SingleTurnAgentLoop? Can we up-level this partial_rollout support so all agent loop classes get it for free?
Specifically for ToolAgentLoop, we may need additional logic to handle tool-use state, cancellation, and recovery. This can be a follow-up change if it is out of scope for this PR.
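One hypothetical shape of that up-leveling (not from this PR; the base-class internals, the token_ids kwarg, and _run are illustrative assumptions):

class AgentLoopBase:
    async def run(self, sampling_params, **kwargs):
        # the base class strips the saved partial generation once, so every
        # subclass (single-turn, tool, ...) gets partial-rollout resumption
        # without duplicating the kwargs handling
        resume_token_ids = kwargs.pop("token_ids", None)
        return await self._run(sampling_params, resume_token_ids=resume_token_ids, **kwargs)

    async def _run(self, sampling_params, resume_token_ids=None, **kwargs):
        raise NotImplementedError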
reward_tensor = self.rm_wg.compute_rm_score(batch)
batch = batch.union(reward_tensor)

if self.config.reward_model.launch_reward_fn_async:
Idea on this: we could start reward computation at the sample level right after generation, without waiting for a full batch to be formed, although running it async with old_log_prob already hides most of the reward-computation wait time in most cases.
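A hedged asyncio sketch of the sample-level idea (generate_one and compute_reward are stand-in callables, not the repo's actual APIs):

import asyncio

async def rollout_with_sample_level_reward(prompts, generate_one, compute_reward):
    # start reward computation for each sample the moment its generation
    # finishes, instead of waiting for the whole batch to be assembled
    async def one(prompt):
        sample = await generate_one(prompt)
        sample["reward"] = await compute_reward(sample)
        return sample

    return await asyncio.gather(*(one(p) for p in prompts))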
Ahh already a PR on this: #3055
According to this:
Background
Currently, rollout suffers from long-tail generations: within one batch, some samples can be much longer than the rest and become stragglers.

Main Idea
To solve this, we implement a StreamScheduler. This strategy keeps fetching data from the data iterator and sending it to the serving engine until the stop conditions are met. It works because we essentially fill each batch with the samples that finish generating first and postpone the long-tail samples (even though those are still sent to the engine, they are eventually 'dropped' because the other prefetched samples reach the output buffer faster). A rough sketch of this loop is given below.
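The sketch is illustrative only; the real StreamScheduler in this PR carries extra state for requeueing, staleness, and GRPO group handling, and it assumes the data iterator never runs dry within a batch:

import asyncio

async def stream_schedule(data_iter, engine, batch_size):
    # over-subscribe the engine and keep whichever generations finish first;
    # whatever is still in flight once the batch is full are the long-tail
    # samples, which get cancelled/postponed
    in_flight, finished = set(), []
    while len(finished) < batch_size:
        while len(in_flight) < 2 * batch_size:
            in_flight.add(asyncio.ensure_future(engine.generate(next(data_iter))))
        done, in_flight = await asyncio.wait(in_flight, return_when=asyncio.FIRST_COMPLETED)
        finished.extend(task.result() for task in done)
    for task in in_flight:
        task.cancel()
    return finished[:batch_size]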
A couple of things we need to be careful with:
Handling the results of partial generations: whether to drop or save the KV cache, whether to save the partial result, and a staleness factor all need to be considered. We also need to be careful with the requeue pattern, since for GRPO the n samples of a group must be requeued together.
Benchmark
Rollout Bench
E2E Bench
- dataset: GSM8K
- max_output_length: 16k
- model: Qwen2-7B
- hardware: 8×H20, dp=8

Rollout throughput: baseline vs stream = 0.11 vs 0.05. Overall, there’s approximately a 2× improvement, and notably, stream mode shows almost no long-tail samples.