[feat] support minimum token load balance in dp attention #7379
Conversation
Summary of Changes
Hello @WANG-GH, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request enhances the data parallel attention mechanism by introducing a new, more intelligent load balancing strategy. The core idea is to minimize the total token load across DP groups by considering both currently processed and in-flight requests, leveraging shared memory for efficient state synchronization. This change is expected to improve overall serving performance and resource utilization.
Highlights
- New Load Balancing Method: I've introduced a new load balancing method for Data Parallel (DP) attention called `minimum_tokens`. This method aims to distribute incoming requests more efficiently by dispatching them to the DP group with the lowest total token load.
- Load Calculation Logic: The `minimum_tokens` method calculates the total load for each DP group by summing two components: `holding_tokens` (tokens currently being processed by the scheduler) and `onfly_req` (requests dispatched by the controller but not yet accepted by the scheduler). This provides a more comprehensive view of each group's real-time load.
- Shared Memory for Load Data: To facilitate real-time load balancing, I've implemented a shared memory mechanism (`DPBalanceMeta`). The TP0 node (dispatcher) gathers load data from all DP groups via this shared memory, allowing it to make informed dispatching decisions. A unique `dp_balance_id` is added to each request to track its status in the 'on-the-fly' queue.
- Performance Improvements: Initial benchmarks show that using the `minimum_tokens` load balancing method leads to improvements in both Time-to-First-Token (TTFT) and Total Batch Throughput (TBT) compared to the default round-robin method, indicating better overall system performance.
- Shared Memory Cleanup: I've added explicit cleanup for the POSIX shared memory segments (`sglang_dp_balance_onfly_info`, `sglang_dp_balance_local_tokens`) during process termination. This prevents shared memory leaks, which could otherwise cause issues with subsequent server launches.
Code Review
This pull request introduces a new load balancing strategy, `minimum_tokens`, for Data Parallel (DP) attention. This method aims to improve performance by dispatching requests to the DP group with the fewest combined holding and in-flight tokens. Key changes include:
- A new `DPBalanceMeta` class to manage shared memory for tracking token loads (`onfly_info` and `local_tokens`) across DP workers, including initialization and cleanup logic.
- Modifications to the `DataParallelController` to incorporate the new scheduling strategy, initialize `DPBalanceMeta`, and ensure shared memory is released.
- Updates to the `Scheduler` to interact with `DPBalanceMeta`, gather load information from DP workers, and update shared memory via the TP0 rank.
- Additions to `io_struct` to include `dp_balance_id` in request objects for tracking.
- An enhanced `kill_process_tree` utility to clean up the newly introduced shared memory segments.
The core logic for load balancing and shared memory management appears sound, with proper locking mechanisms. However, there is a critical issue in how the `local_tokens` shared memory is handled (a missing size header and assertion), and some areas could benefit from improved clarity through named constants and better variable naming. The performance benchmarks provided in the PR description show improvements with the new method.
def get_shared_local_tokens(self) -> List[int]:
    shm = shared_memory.SharedMemory(name=self.shm_name_local_tokens)
    serialized_data = bytes(shm.buf)
    worker_onfly_data = pickle.loads(serialized_data)
    shm.close()
    return worker_onfly_data

def set_shared_local_tokens(self, data: List[int]):
    serialized_data = pickle.dumps(data)
    data_size = len(serialized_data)

    shm = shared_memory.SharedMemory(name=self.shm_name_local_tokens)
    shm.buf[:data_size] = serialized_data

    shm.close()
There's a critical issue in how the `local_tokens` shared memory is handled:
- `set_shared_local_tokens` writes the pickled data but does not prefix it with its size, unlike `set_shared_onfly_info`.
- `get_shared_local_tokens` reads the entire shared memory buffer (`bytes(shm.buf)`) and attempts to unpickle it. If the actual pickled data is smaller than `self.local_tokens_size`, `pickle.loads` will try to interpret trailing garbage data, leading to potential errors or incorrect deserialization.
- `set_shared_local_tokens` does not assert that `serialized_data` fits within `self.local_tokens_size`.
To fix this, `set_shared_local_tokens` should first write the size of the serialized data (using `struct.pack`, as `set_shared_onfly_info` does), and `get_shared_local_tokens` should read this size to correctly slice the buffer before unpickling. An assertion that the data fits within `self.local_tokens_size` should also be added to `set_shared_local_tokens`.
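For illustration, a minimal sketch of the size-prefixed scheme suggested here, assuming the segment has already been created with room for an 8-byte header plus the payload (the helper and constant names are illustrative, not the PR's code):

```python
import pickle
import struct
from multiprocessing import shared_memory
from typing import List

HEADER = struct.Struct("Q")  # 8-byte length prefix written before the pickled payload


def write_local_tokens(shm_name: str, capacity: int, data: List[int]) -> None:
    payload = pickle.dumps(data)
    # Assert the header plus payload fits, as the review asks for.
    assert HEADER.size + len(payload) <= capacity, "pickled data exceeds shared memory size"
    shm = shared_memory.SharedMemory(name=shm_name)
    shm.buf[: HEADER.size] = HEADER.pack(len(payload))
    shm.buf[HEADER.size : HEADER.size + len(payload)] = payload
    shm.close()


def read_local_tokens(shm_name: str) -> List[int]:
    shm = shared_memory.SharedMemory(name=shm_name)
    # Read the length prefix first, then unpickle exactly that many bytes.
    (length,) = HEADER.unpack(bytes(shm.buf[: HEADER.size]))
    data = pickle.loads(bytes(shm.buf[HEADER.size : HEADER.size + length]))
    shm.close()
    return data
```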
self.onfly_info_size = (
    512 * num_workers * 8
)  # max_onfly_req_per_worker * num_workers * dByte
self.local_tokens_size = num_workers * 8
The calculation for `onfly_info_size` and `local_tokens_size` uses a magic number `8` (commented as `dByte`):
- For `onfly_info_size`, `512` represents `max_onfly_req_per_worker`, and the `* 8` seems to be an estimate of the pickled size of each dictionary entry. While the assertion in `set_shared_onfly_info` provides safety, consider introducing named constants such as `MAX_ONFLY_REQ_PER_WORKER = 512` and `ESTIMATED_BYTES_PER_ENTRY = 8` (or a more descriptive name if `8` has a specific meaning) for better readability and maintainability.
- For `local_tokens_size = num_workers * 8`, this calculation might be too small to store a pickled list of `num_workers` integers, especially considering Python's pickle overhead. For example, `pickle.dumps([0] * 16)` is 50 bytes, not `16 * 8 = 128`. More importantly, `set_shared_local_tokens` currently lacks an assertion that the pickled data fits within `local_tokens_size`. Combined with the issue in `get_shared_local_tokens` (see the separate comment), this is problematic.
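For illustration, a minimal sketch of the named-constant version of these size computations (the constant and function names are assumptions, not identifiers from the PR):

```python
from typing import Tuple

MAX_ONFLY_REQ_PER_WORKER = 512  # upper bound on in-flight requests tracked per DP worker
ESTIMATED_BYTES_PER_ENTRY = 8   # rough pickled size per entry, the "dByte" in the original comment


def shm_segment_sizes(num_workers: int) -> Tuple[int, int]:
    # Returns (onfly_info_size, local_tokens_size) in bytes.
    onfly_info_size = MAX_ONFLY_REQ_PER_WORKER * num_workers * ESTIMATED_BYTES_PER_ENTRY
    local_tokens_size = num_workers * ESTIMATED_BYTES_PER_ENTRY
    return onfly_info_size, local_tokens_size
```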
# 2. write the new onfly info to the shm
self.balance_meta.set_shared_onfly_info(onfly_info)

logger.info(f"dp workers {local_tokens=}, {onfly_info=}, {target_worker=}")
This log message uses `logger.info` and includes potentially large data structures (`local_tokens`, `onfly_info`). If requests are frequent, this could lead to excessive logging and performance overhead. Consider changing it to `logger.debug` or making it conditional, for example logging only every N requests or only when a specific debug flag is enabled.
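A hedged sketch of the rate-limited alternative, with hypothetical names not taken from the PR:

```python
import logging

logger = logging.getLogger(__name__)
LOG_EVERY_N_DISPATCHES = 100  # hypothetical sampling interval


def maybe_log_balance_state(dispatch_count, local_tokens, onfly_info, target_worker):
    # Emit at DEBUG level, and only every N dispatches, to keep the overhead low.
    if dispatch_count % LOG_EVERY_N_DISPATCHES == 0:
        logger.debug(f"dp workers {local_tokens=}, {onfly_info=}, {target_worker=}")
```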
init_onfly_req = [{} for _ in range(num_workers)]
self.set_shared_local_tokens(init_local_tokens)
self.set_shared_onfly_info(init_onfly_req)
self.shm1.name
@@ -1757,6 +1776,91 @@ def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
            deepep_mode=DeepEPMode[self.server_args.deepep_mode],
        )

    def handle_dp_balance_data(self, local_batch: ScheduleBatch):
        def gather_dp_balance_info(holding_tokens_list) -> Union[None, List[List[int]]]:
The parameter `holding_tokens_list` in `gather_dp_balance_info` is misleading. Based on its usage in `handle_dp_balance_data` (lines 1854-1855), it is a single integer representing the current worker's holding tokens, not a list.
Later, within `gather_dp_balance_info` (line 1814), `holding_tokens_list` is reassigned to a list of holding tokens gathered from all workers when `self.tp_rank == 0`.
This dual meaning and misnaming can cause confusion. Consider renaming the parameter to something like `current_worker_holding_tokens: int` and the locally gathered list to `all_workers_holding_tokens: List[int]`.
assert len(recv_list) <= 511, (
    "The number of requests received this round is too large. "
    "Please increase gather_tensor_size and onfly_info_size."
)

gather_tensor_size = 512
The values `511` and `512` (for `gather_tensor_size`) are used here and appear to be related to the `max_onfly_req_per_worker` implicitly defined in `DPBalanceMeta`. It would be better to define them as named constants, possibly in `data_parallel_meta.py` or a shared constants module, and import them here. This improves readability and makes it easier to update if the underlying limits change.
def gather_dp_balance_info(holding_tokens_list) -> Union[None, List[List[int]]]:
    """gather recv_dp_balance_id_this_term and holding tokens per worker for dp balance"""
    recv_list = self.recv_dp_balance_id_this_term
    assert len(recv_list) <= 511, (
        "The number of requests received this round is too large. "
        "Please increase gather_tensor_size and onfly_info_size."
    )

    gather_tensor_size = 512
    # recv_tensor: | holding_tokens | len(recv_dp_balance_id) | recv_dp_balance_ids
    recv_tensor = torch.zeros(gather_tensor_size, dtype=torch.int32)
    recv_tensor[0] = holding_tokens_list
    recv_tensor[1] = len(recv_list)  # The first element is the length of the list.
    recv_tensor[2 : len(recv_list) + 2] = torch.tensor(recv_list, dtype=torch.int32)

    if self.tp_rank == 0:
        gathered_list = [
            torch.zeros(gather_tensor_size, dtype=torch.int32)
            for _ in range(self.balance_meta.num_workers)
        ]
    else:
        gathered_list = None

    torch.distributed.gather(recv_tensor, gathered_list, group=self.tp_cpu_group)

    gathered_id_list_per_worker = None
    if self.tp_rank == 0:
        gathered_id_list_per_worker = []
        holding_tokens_list = []
        for tensor in gathered_list:
            holding_tokens_list.append(tensor[0].item())
            list_length = tensor[1].item()
            gathered_id_list_per_worker.append(tensor[2 : list_length + 2].tolist())

    return gathered_id_list_per_worker, holding_tokens_list
The return type hint `Union[None, List[List[int]]]` for `gather_dp_balance_info` is incorrect. The function actually returns a tuple: `(gathered_id_list_per_worker, holding_tokens_list)`.
Based on the logic:
- If `self.tp_rank != 0`, it returns `(None, input_argument_holding_tokens)`.
- If `self.tp_rank == 0`, it returns `(List[List[int]], List[int])` (where the second list is the gathered holding tokens for all workers).
A more accurate type hint would be `Tuple[Optional[List[List[int]]], List[int]]`.
Suggested change (signature only; the body of gather_dp_balance_info shown above is unchanged):
-    def gather_dp_balance_info(holding_tokens_list) -> Union[None, List[List[int]]]:
+    def gather_dp_balance_info(current_worker_holding_tokens: int) -> Tuple[Optional[List[List[int]]], List[int]]:
# while onfly refers to the requests dispatched by the dispatcher but not yet received by the scheduler.
onfly_info = self.balance_meta.get_shared_onfly()
Might it be more formal to name these "on_the_fly" or "in_flight"?
Sure, we can first discuss the current implementation of the algorithm, and then I’ll make a unified update accordingly.
I see, previously it only used round robin without checking token counts.
I think you should clean up the commit history 🤣
I've just cleaned up the commit history and resolved the conflicts with the main branch. Could you please review this PR?
Please fix the lint errors; you can refer to the contribution guide.
Hi, I fixed the lint errors. Could you please trigger the CI?
"Please increase gather_tensor_size and onfly_info_size." | ||
) | ||
|
||
gather_tensor_size = 512 |
Could you add an explanation for this value?
It means the maximum size of the tensor used for gathering data.
def get_next_global_balance_id() -> int:
    INT32_MAX = 2147483647
    current_id = self.global_balance_id
    self.global_balance_id = (self.global_balance_id + 1) % INT32_MAX
Could you explain the meaning of this variable?
Sure, this variable corresponds to the `balance_id` in `TokenizedGenerateReqInput`.
We use it to track the on-the-fly tokens (requests dispatched to workers but not yet received by the scheduler).
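For illustration, a hypothetical sketch of the bookkeeping that `dp_balance_id` enables, based on the description in this thread and in the PR summary (the class and method names are assumptions, not the PR's code):

```python
from typing import Dict, List


class OnflyTracker:
    """Tracks requests dispatched to a DP worker but not yet accepted by its scheduler."""

    def __init__(self, num_workers: int):
        # worker index -> {dp_balance_id: estimated token count}
        self.onfly: List[Dict[int, int]] = [{} for _ in range(num_workers)]

    def on_dispatch(self, worker: int, balance_id: int, num_tokens: int) -> None:
        # Controller side: the request is now "on the fly" toward this worker.
        self.onfly[worker][balance_id] = num_tokens

    def on_accepted(self, worker: int, balance_id: int) -> None:
        # Scheduler side (reported back via the TP0 gather): the request was received,
        # so it no longer counts toward the worker's on-the-fly load.
        self.onfly[worker].pop(balance_id, None)

    def onfly_load(self, worker: int) -> int:
        return sum(self.onfly[worker].values())
```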
def __init__(self, num_workers: int):
    self.num_workers = num_workers
    self._manager = mp.Manager()
    self.mutex = self._manager.Lock()
I was wondering if we could abstract some methods into python/sglang/srt/utils.py and python/sglang/srt/distributed/parallel_state.py?
Sure, I moved it to manager/utils.py.
I tested on 2-node H20 with 4096 requests.
rr: (benchmark results image)
minimum token: (benchmark results image)
I've been following your PR for a long time. I was wondering when it will be merged into the main repository. Do you have a plan for that?
Thank you for the reply! I also hope to merge it into the main branch as soon as possible. Over the past few weeks, I've been actively communicating with Qiaolin and Cheng Wan from the SGLang community, and they have provided many great suggestions for this PR. We are currently confirming with the PD separation team to check if there are any conflicts between our designs.
I've read through your PR thoroughly and have two points I'd like to confirm with you. The first question concerns the existing get_load function. The second question is about the implementation of the shared-memory load state.
@WANG-GH Glad to know this is actively developed. Currently, the Router has the load balancer function which
Thanks for asking. For the first question: I hadn't noticed the get_load function before. When calculating the load, this function, combined with the current batch size, can better accommodate PD separation. For the second question: IPC communication is more suitable for the producer-consumer model, but the current controller needs to frequently check the global instance's status. This programming paradigm is more like using locks to maintain a global state, which doesn't align well with the producer-consumer model. Changing to IPC would be quite awkward.
I hadn't noticed the get_load function before. When calculating the load, this function, combined with the current batch size, can better accommodate PD separation. I will refactor the
Thanks for the reply, I see.
Thanks for your work on this feature, it helps me a lot.
with self.balance_meta.mutex:
    # 1. local_tokens represents the tokens currently inferring on the worker,
    # while onfly refers to the requests dispatched by the dispatcher but not yet received by the scheduler.
    onfly_info = self.balance_meta.get_shared_onfly()
I'm wondering how long `onfly_reqs` can actually survive. It seems the scheduler receives requests from the DPC almost immediately, while `onfly_reqs` are appended in `process_input_requests` and then excluded at the end of `get_next_batch_to_run` every 40 iterations. Given this flow, the comment here might be misleading.
    return list(self.shared_state.local_tokens)

def set_shared_local_tokens(self, data: List[int]):
    self.shared_state.local_tokens = data
Not sure if this is intentional, but in both set_* functions, assigning a plain Python list where a multiprocessing.Manager().list() is expected could replace the managed object, losing the cross-process synchronization.
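For illustration, a small standalone demonstration of the pitfall (not the PR's code): rebinding a variable to a plain list drops the `Manager().list()` proxy, while slice assignment updates the managed object in place.

```python
from multiprocessing import Manager

if __name__ == "__main__":
    manager = Manager()
    shared_tokens = manager.list([0, 0, 0, 0])  # managed list, visible across processes

    # Rebinding drops the proxy: this is now a plain local list, and processes that
    # still hold the original proxy never see the update.
    shared_tokens = [1, 2, 3, 4]

    # In-place slice assignment keeps the same managed object, so the new values are
    # propagated through the manager and stay visible to every process holding the proxy.
    shared_tokens = manager.list([0, 0, 0, 0])
    shared_tokens[:] = [1, 2, 3, 4]
    print(list(shared_tokens))  # [1, 2, 3, 4]
```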
"Please increase gather_tensor_size and onfly_info_size." | ||
) | ||
# The maximum size of the tensor used for gathering data from all workers. | ||
gather_tensor_size = 512 |
Should this be assert len(recv_list) < 511? Otherwise, with 511 ids, recv_tensor would need 1 + 1 + 511 = 513 elements, which would exceed 512.
Also, as the Gemini bot mentioned, holding_tokens_list is misleading when it's actually a token count rather than a list, especially since it's later mixed with real list operations.
LGTM
Motivation & Modifications
When DP attention is enabled, the system decides which DP group to dispatch a request to based on the current load of each DP group.
The load data from all DP groups is gathered once to the TP0 node, where TP0 interacts with the dispatcher via shared memory to share the load information in real time.
The load data consists of two parts:
- holding_tokens: the number of tokens currently being processed by each DP group's scheduler.
- onfly_req: the requests that have been dispatched by the dispatcher but have not yet been accepted by the scheduler (a scheduler only accepts a request when it reaches the recv_req phase).
The dispatcher sums these two parts and selects the DP group with the lowest load.
The advantage of using onfly is that we only need to track the tokens currently being processed by each worker, without worrying about whether it's MTP, overlap, etc.
I added an extra field dp_balance_id: int to the request, so we just need to gather this integer.
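For illustration, a hypothetical sketch of the selection rule described above (the function and argument names are not from the PR):

```python
from typing import Dict, List


def pick_target_dp_group(holding_tokens: List[int], onfly_info: List[Dict[int, int]]) -> int:
    # holding_tokens[i]: tokens currently being processed by DP group i's scheduler.
    # onfly_info[i]: {dp_balance_id: token_count} for requests dispatched to group i
    # but not yet accepted by its scheduler.
    total_load = [
        holding + sum(onfly.values())
        for holding, onfly in zip(holding_tokens, onfly_info)
    ]
    return total_load.index(min(total_load))
```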
To enable this feature, simply add `--load-balance-method minimum_tokens` to the startup arguments.
Performance
Both TBT and TTFT show improvements.
minimum_token: (benchmark results image)
rr: (benchmark results image)
Checklist