
Conversation

Contributor

@WANG-GH commented Jun 20, 2025

Motivation & Modifications

When DP attention is enabled, the system decides which DP group to dispatch a request to based on the current load of each DP group.

The load data from all DP groups is gathered onto the TP0 node, and TP0 shares it with the dispatcher through shared memory so the dispatcher sees the load information in real time.

The load data consists of two parts:

  • holding_tokens: the number of tokens currently being processed by each DP group's scheduler.

  • onfly_req: the requests that have been dispatched by the dispatcher but have not yet been accepted by the scheduler (a scheduler only accepts a request when it reaches the recv_req phase).

The dispatcher sums these two parts and selects the DP group with the lowest load.
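
As a rough sketch of this selection step (the function and variable names, and the shape of the on-fly table, are assumptions for illustration, not the exact SGLang code):

from typing import Dict, List

# Illustrative sketch of the "minimum_tokens" dispatch decision.
def pick_target_worker(
    holding_tokens: List[int],          # tokens each DP scheduler is currently processing
    onfly_info: List[Dict[int, int]],   # per worker: {dp_balance_id: input_len} for dispatched-but-unreceived requests
) -> int:
    total_load = [
        held + sum(onfly.values())
        for held, onfly in zip(holding_tokens, onfly_info)
    ]
    # Dispatch to the DP group with the smallest combined load.
    return total_load.index(min(total_load))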

The advantage of using onfly is that we only need to track the tokens currently being processed by each worker, without worrying about whether MTP, overlap scheduling, etc. are enabled.
I added an extra field dp_balance_id: int to the request, so we only need to gather this integer.

To enable this feature, simply add --load-balance-method minimum_tokens to the startup arguments.

Performance

Both TBT and TTFT show improvements.

python3 -m sglang.bench_serving --backend sglang --num-prompts 100

minimum_tokens:

python3 -m sglang.launch_server \
    --model-path /DeepSeek-R1 \
    --tp 16 \
    --nnodes 2 --node-rank 0 --trust-remote-code \
    --load-balance-method minimum_tokens \
    --dp-size 16 --cuda-graph-max-bs 128 \
    --enable-dp-attention 

============ Serving Benchmark Result ============
Backend:                                 sglang    
Traffic request rate:                    inf       
Max request concurrency:                 not set   
Successful requests:                     100       
Benchmark duration (s):                  57.58     
Total input tokens:                      34308     
Total generated tokens:                  21395     
Total generated tokens (retokenized):    21292     
Request throughput (req/s):              1.74      
Input token throughput (tok/s):          595.83    
Output token throughput (tok/s):         371.57    
Total token throughput (tok/s):          967.40    
Concurrency:                             31.53     
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   18153.78  
Median E2E Latency (ms):                 15861.65  
---------------Time to First Token----------------
Mean TTFT (ms):                          4767.87   
Median TTFT (ms):                        4254.61   
P99 TTFT (ms):                           9760.97   
---------------Inter-Token Latency----------------
Mean ITL (ms):                           63.03     
Median ITL (ms):                         41.31     
P95 ITL (ms):                            57.08     
P99 ITL (ms):                            864.98    
Max ITL (ms):                            7880.85   
==================================================

rr (round robin):

python3 -m sglang.launch_server \
    --model-path /DeepSeek-R1 \
    --tp 16  \
    --nnodes 2 --node-rank 0 --trust-remote-code \
    --dp-size 16 --cuda-graph-max-bs 128 \
    --enable-dp-attention 

============ Serving Benchmark Result ============
Backend:                                 sglang    
Traffic request rate:                    inf       
Max request concurrency:                 not set   
Successful requests:                     100       
Benchmark duration (s):                  63.01     
Total input tokens:                      34308     
Total generated tokens:                  21395     
Total generated tokens (retokenized):    21309     
Request throughput (req/s):              1.59      
Input token throughput (tok/s):          544.44    
Output token throughput (tok/s):         339.52    
Total token throughput (tok/s):          883.97    
Concurrency:                             36.49     
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   22992.97  
Median E2E Latency (ms):                 20915.81  
---------------Time to First Token----------------
Mean TTFT (ms):                          6945.23   
Median TTFT (ms):                        7699.87   
P99 TTFT (ms):                           14878.23  
---------------Inter-Token Latency----------------
Mean ITL (ms):                           75.57     
Median ITL (ms):                         41.34     
P95 ITL (ms):                            55.28     
P99 ITL (ms):                            1460.84   
Max ITL (ms):                            12318.38  
==================================================



@gemini-code-assist bot left a comment


Summary of Changes

Hello @WANG-GH, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request enhances the data parallel attention mechanism by introducing a new, more intelligent load balancing strategy. The core idea is to minimize the total token load across DP groups by considering both currently processed and in-flight requests, leveraging shared memory for efficient state synchronization. This change is expected to improve overall serving performance and resource utilization.

Highlights

  • New Load Balancing Method: I've introduced a new load balancing method for Data Parallel (DP) attention called minimum_tokens. This method aims to distribute incoming requests more efficiently by dispatching them to the DP group with the lowest total token load.
  • Load Calculation Logic: The 'minimum_tokens' method calculates the total load for each DP group by summing two components: holding_tokens (tokens currently being processed by the scheduler) and onfly_req (requests dispatched by the controller but not yet accepted by the scheduler). This provides a more comprehensive view of each group's real-time load.
  • Shared Memory for Load Data: To facilitate real-time load balancing, I've implemented a shared memory mechanism (DPBalanceMeta). The TP0 node (dispatcher) gathers load data from all DP groups via this shared memory, allowing it to make informed dispatching decisions. A unique dp_balance_id is added to each request to track its status in the 'on-the-fly' queue.
  • Performance Improvements: Initial benchmarks show that using the minimum_tokens load balancing method leads to improvements in both Time-to-First-Token (TTFT) and Time-Between-Tokens (TBT) compared to the default round-robin method, indicating better overall system performance.
  • Shared Memory Cleanup: I've added explicit cleanup for the POSIX shared memory segments (sglang_dp_balance_onfly_info, sglang_dp_balance_local_tokens) during process termination. This prevents shared memory leaks, which could otherwise cause issues with subsequent server launches.
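
A minimal sketch of such cleanup using Python's standard shared_memory API (illustrative only, not necessarily the PR's exact code; only the segment names come from the description above):

from multiprocessing import shared_memory

# Remove the named POSIX shm segments so a crashed run doesn't leave stale
# segments behind that break the next server launch.
for name in ("sglang_dp_balance_onfly_info", "sglang_dp_balance_local_tokens"):
    try:
        shm = shared_memory.SharedMemory(name=name)
        shm.close()
        shm.unlink()  # deletes the backing file under /dev/shm
    except FileNotFoundError:
        pass  # nothing to clean up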


@gemini-code-assist bot left a comment


Code Review

This pull request introduces a new load balancing strategy, minimum_tokens, for Data Parallel (DP) attention. This method aims to improve performance by dispatching requests to the DP group with the fewest combined holding and in-flight tokens. Key changes include:

  • A new DPBalanceMeta class to manage shared memory for tracking token loads (onfly_info and local_tokens) across DP workers, including initialization and cleanup logic.
  • Modifications to the DataParallelController to incorporate the new scheduling strategy, initialize DPBalanceMeta, and ensure shared memory is released.
  • Updates to the Scheduler to interact with DPBalanceMeta, gather load information from DP workers, and update shared memory via the TP0 rank.
  • Additions to io_struct to include dp_balance_id in request objects for tracking.
  • Enhanced kill_process_tree utility to clean up the newly introduced shared memory segments.

The core logic for load balancing and shared memory management appears sound, with proper locking mechanisms. However, there's a critical issue in how local_tokens shared memory is handled (missing size header and assertion), and some areas could benefit from improved clarity with constants and variable naming. The performance benchmarks provided in the PR description show improvements with the new method.

Comment on lines 82 to 96
def get_shared_local_tokens(self) -> List[int]:
    shm = shared_memory.SharedMemory(name=self.shm_name_local_tokens)
    serialized_data = bytes(shm.buf)
    worker_onfly_data = pickle.loads(serialized_data)
    shm.close()
    return worker_onfly_data

def set_shared_local_tokens(self, data: List[int]):
    serialized_data = pickle.dumps(data)
    data_size = len(serialized_data)

    shm = shared_memory.SharedMemory(name=self.shm_name_local_tokens)
    shm.buf[:data_size] = serialized_data

    shm.close()

critical

There's a critical issue in how local_tokens shared memory is handled:

  1. set_shared_local_tokens writes the pickled data but does not prefix it with its size, unlike set_shared_onfly_info.
  2. get_shared_local_tokens reads the entire shared memory buffer (bytes(shm.buf)) and attempts to unpickle it. If the actual pickled data is smaller than self.local_tokens_size, pickle.loads will try to interpret trailing garbage data, leading to potential errors or incorrect deserialization.
  3. set_shared_local_tokens does not assert that the serialized_data fits within self.local_tokens_size.

To fix this, set_shared_local_tokens should be modified to first write the size of the serialized data (using struct.pack like in set_shared_onfly_info), and get_shared_local_tokens should read this size to correctly slice the buffer before unpickling. An assertion for data size should also be added to set_shared_local_tokens.
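
A minimal sketch of that fix, assuming a 4-byte length header like the one the review describes for set_shared_onfly_info (the header size and exact layout are assumptions):

import pickle
import struct
from multiprocessing import shared_memory
from typing import List

HEADER_SIZE = 4  # bytes used to store the payload length (assumption)

def set_shared_local_tokens(self, data: List[int]) -> None:
    serialized = pickle.dumps(data)
    assert HEADER_SIZE + len(serialized) <= self.local_tokens_size, (
        "pickled local_tokens does not fit in the shared memory segment"
    )
    shm = shared_memory.SharedMemory(name=self.shm_name_local_tokens)
    shm.buf[:HEADER_SIZE] = struct.pack(">I", len(serialized))            # write the size header
    shm.buf[HEADER_SIZE : HEADER_SIZE + len(serialized)] = serialized     # then the payload
    shm.close()

def get_shared_local_tokens(self) -> List[int]:
    shm = shared_memory.SharedMemory(name=self.shm_name_local_tokens)
    (size,) = struct.unpack(">I", bytes(shm.buf[:HEADER_SIZE]))           # read the size header
    data = pickle.loads(bytes(shm.buf[HEADER_SIZE : HEADER_SIZE + size])) # slice before unpickling
    shm.close()
    return data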

Comment on lines 22 to 25
self.onfly_info_size = (
    512 * num_workers * 8
)  # max_onfly_req_per_worker * num_workers * dByte
self.local_tokens_size = num_workers * 8

high

The calculation for onfly_info_size and local_tokens_size uses a magic number 8 (commented as dByte).

  1. For onfly_info_size, 512 represents max_onfly_req_per_worker. The * 8 seems to be an estimate for the pickled size of dictionary entries. While the assertion in set_shared_onfly_info provides safety, consider making MAX_ONFLY_REQ_PER_WORKER = 512 and ESTIMATED_BYTES_PER_ENTRY = 8 (or a more descriptive name if 8 has a specific meaning) as named constants for better readability and maintainability.

  2. For local_tokens_size = num_workers * 8, this calculation might be too small for storing a pickled list of num_workers integers, especially considering Python's pickle overhead. For example, pickle.dumps([0]*16) is 50 bytes, not 16*8=128. More importantly, set_shared_local_tokens currently lacks an assertion to check if the pickled data fits within local_tokens_size. This, combined with the issue in get_shared_local_tokens (see separate comment), is problematic.
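
As a readability-only sketch of the first suggestion (the constant names are my own; the per-entry byte estimates are assumptions carried over from the original comment):

MAX_ONFLY_REQ_PER_WORKER = 512        # max in-flight requests tracked per DP worker
ESTIMATED_BYTES_PER_ONFLY_ENTRY = 8   # rough pickled size per {balance_id: tokens} entry (assumption)
BYTES_PER_TOKEN_COUNT = 8             # rough pickled size per holding-token integer (assumption)

def compute_shm_sizes(num_workers: int) -> tuple:
    # Same arithmetic as the original code, just with named constants.
    onfly_info_size = MAX_ONFLY_REQ_PER_WORKER * num_workers * ESTIMATED_BYTES_PER_ONFLY_ENTRY
    local_tokens_size = num_workers * BYTES_PER_TOKEN_COUNT
    return onfly_info_size, local_tokens_size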

# 2. write the new onfly info to the shm
self.balance_meta.set_shared_onfly_info(onfly_info)

logger.info(f"dp workers {local_tokens=}, {onfly_info=}, {target_worker=}")

medium

This log message uses logger.info and includes potentially large data structures (local_tokens, onfly_info). If requests are frequent, this could lead to excessive logging and performance overhead. Consider changing this to logger.debug or making it conditional, for example, logging only every N requests or if a specific debug flag is enabled.

init_onfly_req = [{} for _ in range(num_workers)]
self.set_shared_local_tokens(init_local_tokens)
self.set_shared_onfly_info(init_onfly_req)
self.shm1.name

medium

This line self.shm1.name accesses the name attribute of the shared memory object but does not use the value. It has no side effect and can be removed if it's not intended for a specific purpose (e.g., debugging during development).

@@ -1757,6 +1776,91 @@ def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
    deepep_mode=DeepEPMode[self.server_args.deepep_mode],
)

def handle_dp_balance_data(self, local_batch: ScheduleBatch):
    def gather_dp_balance_info(holding_tokens_list) -> Union[None, List[List[int]]]:

medium

The parameter holding_tokens_list in gather_dp_balance_info is misleading. Based on its usage in handle_dp_balance_data (line 1854-1855), it's a single integer representing the current worker's holding tokens, not a list.

Later, within gather_dp_balance_info (line 1814), holding_tokens_list is re-assigned to be a list of holding tokens gathered from all workers if self.tp_rank == 0.

This dual meaning and misnaming can cause confusion. Consider renaming the parameter to something like current_worker_holding_tokens: int and the locally gathered list to all_workers_holding_tokens: List[int].

Comment on lines 1783 to 1871
assert len(recv_list) <= 511, (
    "The number of requests received this round is too large. "
    "Please increase gather_tensor_size and onfly_info_size."
)

gather_tensor_size = 512

medium

The values 511 and 512 (for gather_tensor_size) are used here and seem related to max_onfly_req_per_worker defined implicitly in DPBalanceMeta. It would be better to define these as named constants, possibly in data_parallel_meta.py or a shared constants module, and import them here. This improves readability and makes it easier to update if the underlying limits change.

Comment on lines 1780 to 1905
def gather_dp_balance_info(holding_tokens_list) -> Union[None, List[List[int]]]:
    """gather recv_dp_balance_id_this_term and holding tokens per worker for dp balance"""
    recv_list = self.recv_dp_balance_id_this_term
    assert len(recv_list) <= 511, (
        "The number of requests received this round is too large. "
        "Please increase gather_tensor_size and onfly_info_size."
    )

    gather_tensor_size = 512
    # recv_tensor: | holding_tokens | len(recv_dp_balance_id) | recv_dp_balance_ids
    recv_tensor = torch.zeros(gather_tensor_size, dtype=torch.int32)
    recv_tensor[0] = holding_tokens_list
    recv_tensor[1] = len(recv_list)  # The first element is the length of the list.
    recv_tensor[2 : len(recv_list) + 2] = torch.tensor(recv_list, dtype=torch.int32)

    if self.tp_rank == 0:
        gathered_list = [
            torch.zeros(gather_tensor_size, dtype=torch.int32)
            for _ in range(self.balance_meta.num_workers)
        ]
    else:
        gathered_list = None

    torch.distributed.gather(recv_tensor, gathered_list, group=self.tp_cpu_group)

    gathered_id_list_per_worker = None
    if self.tp_rank == 0:
        gathered_id_list_per_worker = []
        holding_tokens_list = []
        for tensor in gathered_list:
            holding_tokens_list.append(tensor[0].item())
            list_length = tensor[1].item()
            gathered_id_list_per_worker.append(tensor[2 : list_length + 2].tolist())

    return gathered_id_list_per_worker, holding_tokens_list

medium

The return type hint Union[None, List[List[int]]] for gather_dp_balance_info is incorrect. The function actually returns a tuple: (gathered_id_list_per_worker, holding_tokens_list).
Based on the logic:

  • If self.tp_rank != 0, it returns (None, input_argument_holding_tokens).
  • If self.tp_rank == 0, it returns (List[List[int]], List[int]) (where the second list is the gathered holding tokens for all workers).

A more accurate type hint would be Tuple[Optional[List[List[int]]], List[int]].

Suggested change (signature only; the function body is unchanged from the snippet above):

def gather_dp_balance_info(current_worker_holding_tokens: int) -> Tuple[Optional[List[List[int]]], List[int]]:

Comment on lines +296 to +297
# while onfly refers to the requests dispatched by the dispatcher but not yet received by the scheduler.
onfly_info = self.balance_meta.get_shared_onfly()
Contributor

@Edenzzzz Jun 23, 2025

might be more formal to name these as "on_the_fly" or "in_flight"?

Contributor Author

Sure, we can first discuss the current implementation of the algorithm, and then I’ll make a unified update accordingly.

@Edenzzzz
Contributor

I see, previously it only used round robin without checking the token count.

@JustinTong0323
Collaborator

I think you should clean the commit history🤣

@WANG-GH
Contributor Author

WANG-GH commented Jun 30, 2025

I've just cleaned up the commit history and resolved the conflicts with the main branch. Could you please review this PR?
@JustinTong0323

@JustinTong0323
Collaborator

Please fix the lint errors; you can refer to the contribution guide.

@WANG-GH
Contributor Author

WANG-GH commented Jul 1, 2025

Hi, I fixed the lint errors. Could you please trigger the CI?

"Please increase gather_tensor_size and onfly_info_size."
)

gather_tensor_size = 512
Collaborator

Could you add an explanation for this value?

Contributor Author

It means the maximum size of the tensor used for gathering data.

def get_next_global_balance_id() -> int:
    INT32_MAX = 2147483647
    current_id = self.global_balance_id
    self.global_balance_id = (self.global_balance_id + 1) % INT32_MAX
Collaborator

Could you explain the meaning of this variable?

Contributor Author

Sure, this variable corresponds to the balance_id in TokenizedGenerateReqInput.
We use it to control the number of onfly requests (requests dispatched to workers but not yet received by the scheduler).
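
For illustration, a sketch of how the gathered ids could be used on TP0 to prune the on-fly table (the method names follow the ones mentioned in this thread; the helper itself and its signature are hypothetical):

# Illustrative sketch of clearing the on-fly bookkeeping with gathered dp_balance_ids.
def write_shared_dp_balance_info(meta, recv_ids_per_worker, holding_tokens_per_worker):
    with meta.mutex:
        onfly_list = meta.get_shared_onfly()
        # Requests whose balance_id was received by a scheduler are no longer "on fly".
        for worker_id, received_ids in enumerate(recv_ids_per_worker):
            for balance_id in received_ids:
                onfly_list[worker_id].pop(balance_id, None)
        meta.set_shared_onfly_info(onfly_list)
        meta.set_shared_local_tokens(holding_tokens_per_worker)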

def __init__(self, num_workers: int):
    self.num_workers = num_workers
    self._manager = mp.Manager()
    self.mutex = self._manager.Lock()
Collaborator

I was wondering if we could abstract some methods into python/sglang/srt/utils.py and python/sglang/srt/distributed/parallel_state.py?

Contributor Author

Sure, I moved it to manager/utils.py.

@WANG-GH
Contributor Author

WANG-GH commented Jul 25, 2025

I tested on a 2-node H20 setup with 4096 requests:

# test command
python3 -m sglang.bench_serving --backend sglang --num-prompts 4096 --dataset-path /sgl-workspace/sharegpt.json

rr (round robin):

============ Serving Benchmark Result ============
Backend:                                 sglang    
Traffic request rate:                    inf       
Max request concurrency:                 not set   
Successful requests:                     4096      
Benchmark duration (s):                  263.31    
Total input tokens:                      1294121   
Total generated tokens:                  787217    
Total generated tokens (retokenized):    783897    
Request throughput (req/s):              15.56     
Input token throughput (tok/s):          4914.74   
Output token throughput (tok/s):         2989.65   
Total token throughput (tok/s):          7904.38   
Concurrency:                             2123.68   
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   136522.09 
Median E2E Latency (ms):                 130958.79 
---------------Time to First Token----------------
Mean TTFT (ms):                          60330.23  
Median TTFT (ms):                        63801.38  
P99 TTFT (ms):                           113752.40 
---------------Inter-Token Latency----------------
Mean ITL (ms):                           396.08    
Median ITL (ms):                         143.72    
P95 ITL (ms):                            654.97    
P99 ITL (ms):                            1512.92   
Max ITL (ms):                            105775.01 
==================================================

minimum_tokens:

============ Serving Benchmark Result ============
Backend:                                 sglang    
Traffic request rate:                    inf       
Max request concurrency:                 not set   
Successful requests:                     4096      
Benchmark duration (s):                  262.36    
Total input tokens:                      1294121   
Total generated tokens:                  787217    
Total generated tokens (retokenized):    783888    
Request throughput (req/s):              15.61     
Input token throughput (tok/s):          4932.61   
Output token throughput (tok/s):         3000.52   
Total token throughput (tok/s):          7933.12   
Concurrency:                             2136.54   
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   136851.55 
Median E2E Latency (ms):                 129838.93 
---------------Time to First Token----------------
Mean TTFT (ms):                          61737.04  
Median TTFT (ms):                        64347.30  
P99 TTFT (ms):                           111977.50 
---------------Inter-Token Latency----------------
Mean ITL (ms):                           390.17    
Median ITL (ms):                         143.37    
P95 ITL (ms):                            567.51    
P99 ITL (ms):                            1079.40   
Max ITL (ms):                            104110.49 
==================================================

@ltdo111

ltdo111 commented Jul 31, 2025

I've been following your PR for a long time. I was wondering when it will be merged into the main repository. Do you have a plan for that?

@WANG-GH
Contributor Author

WANG-GH commented Jul 31, 2025

I've been following your PR for a long time. I was wondering when it will be merged into the main repository. Do you have a plan for that?

Thank you for the reply! I also hope to merge it into the main branch as soon as possible. Over the past few weeks, I've been actively communicating with Qiaolin and Cheng Wan from the SGLang community, and they have provided many great suggestions for this PR. We are currently confirming with the PD separation team to check if there are any conflicts between our designs.
@ch-wan @Qiaolin-Yu

@ltdo111

ltdo111 commented Jul 31, 2025

I've been following your PR for a long time. I was wondering when it will be merged into the main repository. Do you have a plan for that?

Thank you for the reply! I also hope to merge it into the main branch as soon as possible. Over the past few weeks, I've been actively communicating with Qiaolin and Cheng Wan from the SGLang community, and they have provided many great suggestions for this PR. We are currently confirming with the PD separation team to check if there are any conflicts between our designs. @ch-wan @Qiaolin-Yu

I've read through your PR thoroughly and have two points I'd like to confirm with you:

The first question: There is a get_load method in the scheduler, but it isn't reused in this PR. After reading through your code, my understanding is that the current minimum token load balancing method takes into account the load brought by output token IDs during the decode phase, making the load statistics more accurate. I wonder if this understanding is correct.

The second question is about the implementation of data_parallel_controller and IPC communication. The community originally adopted a unified communication method using zmq, but this PR introduces a new mp API to implement communication between dpc and the scheduler. Have you considered switching the communication method from mp to IPC communication? This would be more consistent with the community style, provide a single unified implementation for inter-process communication, and make it easier for subsequent open-source contributors to read the code.

@PanXun2

PanXun2 commented Jul 31, 2025

@WANG-GH Glad to know this is actively developed. Currently, the Router has a load balancer that calls get_loads on each sglang instance, so it already has a mechanism for collecting workloads. Do you have any plan to reuse that load information, or any effort to share code between these similar functions?

@WANG-GH
Contributor Author

WANG-GH commented Jul 31, 2025

I've read through your PR thoroughly and have two points I'd like to confirm with you:

The first question: There is a get_load method in the scheduler, but it isn't reused in this PR. After reading through your code, my understanding is that the current minimum token load balancing method takes into account the load brought by output token IDs during the decode phase, making the load statistics more accurate. I wonder if this understanding is correct.

The second question is about the implementation of data_parallel_controller and IPC communication. The community originally adopted a unified communication method using zmq, but this PR introduces a new mp API to implement communication between dpc and the scheduler. Have you considered switching the communication method from mp to IPC communication? This would be more consistent with the community style, provide a single unified implementation for inter-process communication, and make it easier for subsequent open-source contributors to read the code.

Thanks for asking.

For the first question: I hadn't noticed the get_load function before. When calculating the load, this function, combined with the current batch size, can better accommodate PD separation.

For the second question: IPC communication is more suitable for the producer-consumer model, but the current controller needs to frequently check the global instance's status. This programming paradigm is more like using locks to maintain a global state, which doesn't align well with the producer-consumer model. Changing to IPC would be quite awkward.

@WANG-GH
Contributor Author

WANG-GH commented Jul 31, 2025

@WANG-GH Glad to know this is actively developed. Currently, the Router has a load balancer that calls get_loads on each sglang instance, so it already has a mechanism for collecting workloads. Do you have any plan to reuse that load information, or any effort to share code between these similar functions?

I hadn't noticed the get_load function before. When calculating the load, this function, combined with the current batch size, can better accommodate PD separation. I will refactor the handle_dp_balance_data function to use get_loads before the gather.

@ch-wan self-assigned this Jul 31, 2025
@ltdo111

ltdo111 commented Jul 31, 2025

I've read through your PR thoroughly and have two points I'd like to confirm with you:
The first question: There is a get_load method in the scheduler, but it isn't reused in this PR. After reading through your code, my understanding is that the current minimum token load balancing method takes into account the load brought by output token IDs during the decode phase, making the load statistics more accurate. I wonder if this understanding is correct.
The second question is about the implementation of data_parallel_controller and IPC communication. The community originally adopted a unified communication method using zmq, but this PR introduces a new mp API to implement communication between dpc and the scheduler. Have you considered switching the communication method from mp to IPC communication? This would be more consistent with the community style, provide a single unified implementation for inter-process communication, and make it easier for subsequent open-source contributors to read the code.

Thanks for asking.

For the first question: I hadn't noticed the get_load function before. When calculating the load, this function, combined with the current batch size, can better accommodate PD separation.

For the second question: IPC communication is more suitable for the producer-consumer model, but the current controller needs to frequently check the global instance's status. This programming paradigm is more like using locks to maintain a global state, which doesn't align well with the producer-consumer model. Changing to IPC would be quite awkward.

Thanks for the reply, I see.

@ch-wan added and then removed the ready-to-merge label Aug 3, 2025
@ollybbmonster left a comment

Thanks for your work on this feature, it helps me a lot.

with self.balance_meta.mutex:
    # 1. local_tokens represents the tokens currently inferring on the worker,
    # while onfly refers to the requests dispatched by the dispatcher but not yet received by the scheduler.
    onfly_info = self.balance_meta.get_shared_onfly()


I'm wondering how long onfly_reqs can actually survive.
It seems the scheduler receives requests from the DPC almost immediately.
Meanwhile, onfly_reqs are appended in process_input_requests and then excluded at the end of get_next_batch_to_run every 40 iterations.

Given this flow, the comment here might be misleading.

    return list(self.shared_state.local_tokens)

def set_shared_local_tokens(self, data: List[int]):
    self.shared_state.local_tokens = data


Not sure if this is intentional, but in both set_* functions, assigning a plain Python list directly over a multiprocessing.Manager().list() could replace the managed object, losing the cross-process synchronization.
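
For reference, a tiny self-contained sketch of the difference (plain Manager().list(), illustrative only): mutating the proxy in place keeps it shared, while rebinding the attribute would swap in an ordinary local list.

from multiprocessing import Manager
from typing import List

if __name__ == "__main__":
    manager = Manager()
    local_tokens = manager.list([0, 0, 0, 0])  # proxy shared across processes

    def set_shared_local_tokens(data: List[int]) -> None:
        # Slice assignment mutates the managed list in place, so other processes
        # still see the update through the same proxy object.
        local_tokens[:] = data

    set_shared_local_tokens([3, 1, 4, 1])
    print(list(local_tokens))  # -> [3, 1, 4, 1]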

"Please increase gather_tensor_size and onfly_info_size."
)
# The maximum size of the tensor used for gathering data from all workers.
gather_tensor_size = 512


Should this be assert len(recv_list) < 511?
Otherwise recv_tensor would need 1 + 1 + 511 slots, which exceeds 512.

Also, as the Gemini bot mentioned, holding_tokens_list is misleading when it's actually a token count rather than a list, especially since it's later mixed with real list operations.
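
For illustration only (the constant names and helper are hypothetical, not the PR's code), one way to keep the assertion and the buffer size in sync is to derive both from the same cap:

import torch

MAX_RECV_IDS_PER_ROUND = 512                      # assumed cap on balance ids gathered per round
GATHER_TENSOR_SIZE = 2 + MAX_RECV_IDS_PER_ROUND   # holding_tokens + list length + ids

def build_recv_tensor(holding_tokens: int, recv_ids: list) -> torch.Tensor:
    assert len(recv_ids) <= MAX_RECV_IDS_PER_ROUND, "too many ids; raise MAX_RECV_IDS_PER_ROUND"
    t = torch.zeros(GATHER_TENSOR_SIZE, dtype=torch.int32)
    t[0] = holding_tokens
    t[1] = len(recv_ids)
    t[2 : 2 + len(recv_ids)] = torch.tensor(recv_ids, dtype=torch.int32)
    return t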

narutolhy pushed a commit to narutolhy/sglang that referenced this pull request Aug 17, 2025
narutolhy pushed a commit to narutolhy/sglang that referenced this pull request Aug 18, 2025
MahmoudAshraf97 pushed a commit to MahmoudAshraf97/sglang that referenced this pull request Sep 8, 2025
@whybeyoung
Collaborator

LGTM
