Description
@xiezhq-hermann
Background of the problem:
We use HiRadixCache in a PD-disaggregation setup with the write_back policy. The local radix tree sends update events when nodes are added or deleted on rank 0, and a global radix tree is adjusted according to those events. When a request arrives, we first match it against the global radix tree and choose the P node and D node based on the matched prefix length and the load. We found that, even when distinguishing between instances, the match count in the global tree is sometimes much larger than the match count in the local tree. It looks like the host indices are not being matched.
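For context, here is a rough sketch of the routing step described above. All names are hypothetical and only illustrate the selection rule (longest global prefix match, then lowest load); this is not the actual scheduler code.

from dataclasses import dataclass
from typing import List

@dataclass
class InstanceStats:
    # hypothetical per-instance view derived from the global radix tree
    instance_id: str
    prefix_match_len: int  # matched prefix tokens reported by the global tree
    load: int              # running / queued requests on that instance

def pick_instance(candidates: List[InstanceStats]) -> InstanceStats:
    # prefer the longest prefix match, break ties by the lowest load
    return max(candidates, key=lambda s: (s.prefix_match_len, -s.load))

# The symptom above: prefix_match_len taken from the global tree can be much
# larger than what the chosen instance's local tree can actually reuse.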
While troubleshooting this, we ran into the following issues:
1. pending_nodes is not used
sglang/python/sglang/srt/mem_cache/hiradix_cache.py
Lines 141 to 179 in 8f783c1
def evict(self, num_tokens: int):
    leaves = self._collect_leaves_device()
    heapq.heapify(leaves)

    num_evicted = 0
    pending_nodes = []
    while num_evicted < num_tokens and len(leaves):
        x = heapq.heappop(leaves)
        if x.lock_ref > 0:
            continue

        if x.host_value is None:
            if self.cache_controller.write_policy == "write_back":
                num_evicted += self.write_backup(x)
            elif self.cache_controller.write_policy == "write_through_selective":
                num_evicted += self._evict_write_through_selective(x)
            else:
                assert (
                    self.cache_controller.write_policy != "write_through"
                ), "write_through should be inclusive"
                raise NotImplementedError
        else:
            num_evicted += self._evict_write_through(x)

        for child in x.parent.children.values():
            if child in pending_nodes:
                continue
            if not child.evicted:
                break
        else:
            # all children are evicted or no children
            heapq.heappush(leaves, x.parent)

    if self.cache_controller.write_policy == "write_back":
        # blocking till all write back complete
        while len(self.ongoing_write_through) > 0:
            self.writing_check()
            time.sleep(0.1)
Because pending_nodes is never populated, the "if child in pending_nodes" check can never match, so the parent node does not get pushed back into the heap while its children's write backs are in flight. It should probably be appended to here:
if self.cache_controller.write_policy == "write_back":
    num_evicted += self.write_backup(x)
    ----> pending_nodes.append(x) <----
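A minimal sketch of how that change would sit inside the eviction loop, trimmed to the write_back branch (not a tested patch):

while num_evicted < num_tokens and len(leaves):
    x = heapq.heappop(leaves)
    if x.lock_ref > 0:
        continue

    if x.host_value is None and self.cache_controller.write_policy == "write_back":
        num_evicted += self.write_backup(x)
        # suggested fix: remember nodes whose backup write is still in flight
        pending_nodes.append(x)
    # (other write policies handled as in the original code)

    for child in x.parent.children.values():
        if child in pending_nodes:
            # treat in-flight children like evicted ones
            continue
        if not child.evicted:
            break
    else:
        # all children evicted or pending, so the parent becomes a candidate
        heapq.heappush(leaves, x.parent)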
2. token_to_kv_pool_allocator does not release device_indices when write_policy is write_back
sglang/python/sglang/srt/mem_cache/hiradix_cache.py
Lines 107 to 121 in 06d0a3d
def writing_check(self):
    queue_size = torch.tensor(
        self.cache_controller.ack_write_queue.qsize(), dtype=torch.int
    )
    if torch.distributed.get_world_size(group=self.tp_group) > 1:
        # synchrnoize TP workers to make the same update to radix cache
        torch.distributed.all_reduce(
            queue_size,
            op=torch.distributed.ReduceOp.MIN,
            group=self.tp_group,
        )
    for _ in range(queue_size.item()):
        ack_id = self.cache_controller.ack_write_queue.get()
        self.dec_lock_ref(self.ongoing_write_through[ack_id])
        del self.ongoing_write_through[ack_id]
The device_indices are never returned to token_to_kv_pool_allocator when write_policy is write_back; they should probably be released in writing_check once the write back has been acknowledged:
def writing_check(self):
    ...
    for _ in range(queue_size.item()):
        ack_id = self.cache_controller.ack_write_queue.get()
        self.dec_lock_ref(self.ongoing_write_through[ack_id])
        ----> self._evict_write_through(self.ongoing_write_through[ack_id]) <----
        del self.ongoing_write_through[ack_id]
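What "release" would mean here, roughly: once the backup write is acknowledged, the node's device KV indices can be returned to token_to_kv_pool_allocator. The helper below is hypothetical (I have not checked the exact body of _evict_write_through at this commit) and only illustrates the intent:

def _release_device_indices(self, node):
    # hypothetical helper: the host copy is complete, so the device KV slots
    # referenced by this node can be freed and the node becomes host-only
    if node.value is not None:
        self.token_to_kv_pool_allocator.free(node.value)
        node.value = None  # assumption: "evicted" means the device value is gone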
3. inc_lock_ref called during write back keeps all parent nodes locked.
write_backup calls inc_lock_ref on the node being written back (see below), and inc_lock_ref also locks the node's parents, so those parents keep lock_ref > 0 and cannot be added to the heap during eviction:
sglang/python/sglang/srt/mem_cache/hiradix_cache.py
Lines 90 to 93 in 06d0a3d
if host_indices is not None:
    node.host_value = host_indices
    self.ongoing_write_through[node.id] = node
    self.inc_lock_ref(node)
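For reference, inc_lock_ref walks up the tree, which is why the parents end up locked as well (paraphrased from the radix cache implementation; details may differ slightly at this commit):

def inc_lock_ref(self, node: TreeNode):
    # paraphrased: the lock is applied to the node and to every ancestor, so
    # while a backup write is in flight all parents have lock_ref > 0 and
    # evict() skips them when they are popped from the heap
    delta = 0
    while node != self.root_node:
        if node.lock_ref == 0:
            delta -= len(node.value)
        node.lock_ref += 1
        node = node.parent
    return delta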
sglang/python/sglang/srt/mem_cache/hiradix_cache.py
Lines 148 to 151 in 06d0a3d
x = heapq.heappop(leaves)
if x.lock_ref > 0:
    continue