@xiezhq-hermann (Collaborator) commented Jan 1, 2025

Motivation

While RadixTree-based context caching provides significant performance benefits, these gains are not always fully realized. A key bottleneck is the capacity limit of GPU memory. Currently, SGLang stores historical KV caches exclusively in GPU memory; whenever more memory is required for batch execution, existing caches are discarded.

To address this issue, we propose a hierarchical caching mechanism for LLM serving, treating GPU memory as an L1 cache, host memory as an L2 cache, and disk as an L3 cache (future). This PR introduces such a mechanism in SGLang through a separate host memory pool that backs up KV caches, allowing them to be reloaded into GPU memory when needed.
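As a toy illustration of the L1/L2 tiering described above (all class and method names here are invented for this sketch, not SGLang's actual implementation), an LRU cache that spills evictions from a small "GPU" tier into a larger "host" tier might look like:

```python
from collections import OrderedDict

class TieredKVCache:
    """Toy two-tier cache: a small L1 tier (stands in for GPU memory)
    backed by a larger L2 tier (stands in for host memory).
    Hypothetical names; not part of the SGLang codebase."""

    def __init__(self, l1_capacity: int, l2_capacity: int):
        self.l1 = OrderedDict()  # small, fast tier
        self.l2 = OrderedDict()  # larger backup tier
        self.l1_capacity = l1_capacity
        self.l2_capacity = l2_capacity

    def get(self, key):
        if key in self.l1:                  # L1 hit
            self.l1.move_to_end(key)
            return self.l1[key]
        if key in self.l2:                  # L2 hit: reload into L1
            value = self.l2.pop(key)
            self.put(key, value)
            return value
        return None                         # miss: KV must be recomputed

    def put(self, key, value):
        self.l1[key] = value
        self.l1.move_to_end(key)
        while len(self.l1) > self.l1_capacity:
            # Evict the least-recently-used entry from L1 into L2
            # instead of discarding it outright.
            old_key, old_value = self.l1.popitem(last=False)
            self.l2[old_key] = old_value
            while len(self.l2) > self.l2_capacity:
                self.l2.popitem(last=False)  # final eviction: data is lost
```

The point of the sketch is the `get` path: an L2 hit avoids recomputation by reloading the entry into L1, which is exactly the benefit the host memory pool provides.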

Modifications

  • A HiRadixCache that extends RadixCache with host memory addresses and synchronization mechanisms.
  • A host memory pool that synchronizes with the device memory pool of KV caches.
  • A memory controller that implements efficient data transfer between host and device, and handles various cache write policies for hierarchical caching.

Todo:

  • Update benchmark results.
  • Remove deprecated design and implementation.

Checklist

  • Format your code according to the Contributor Guide.
  • Add unit tests as outlined in the Contributor Guide.
  • Update documentation as needed, including docstrings or example tutorials.

@zhyncs (Member) commented Jan 1, 2025

It's amazing! Happy new year!

@zhyncs added the enhancement (New feature or request) label Jan 1, 2025
@Ying1123 merged commit 6c7a152 into main Feb 24, 2025
17 of 21 checks passed
@Ying1123 deleted the xiezhq-hierarchical branch Feb 24, 2025
@lambert0312 (Contributor)

DeepSeek MLA is not supported yet, and an error will be reported when starting the model:

  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 1849, in run_scheduler_process
    scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 305, in __init__
    HiRadixCache(
  File "/sgl-workspace/sglang/python/sglang/srt/mem_cache/hiradix_cache.py", line 26, in __init__
    self.token_to_kv_pool_host = MLATokenToKVPoolHost(token_to_kv_pool)
  File "/sgl-workspace/sglang/python/sglang/srt/mem_cache/memory_pool.py", line 461, in __init__
    self.head_num = device_pool.head_num
AttributeError: 'MLATokenToKVPool' object has no attribute 'head_num'

@xiezhq-hermann (Collaborator, Author)

> DeepSeek MLA is not supported yet, and an error will be reported when starting the model: […]

Thank you @lambert0312 for pointing this out. Yes, this feature is still at an early stage and currently supports only MHA- and GQA-style memory pools. I will keep you posted once MLA is supported, which should be soon.
For further questions about this feature, feel free to reach out to me on the SGLang Slack for a prompter reply.
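The AttributeError in the report above stems from the host pool assuming an MHA/GQA-style layout (per-head dimensions), while an MLA pool stores a compressed latent instead. A minimal sketch of type-based dispatch that would avoid it, using simplified stand-in classes (not SGLang's real pools):

```python
class MHATokenToKVPoolStandIn:
    """Stand-in for an MHA/GQA device pool exposing a per-head layout."""
    head_num = 8
    head_dim = 128

class MLATokenToKVPoolStandIn:
    """Stand-in for an MLA device pool: one compressed latent per token,
    so there is no `head_num` attribute to read."""
    kv_lora_rank = 512

def make_host_pool(device_pool):
    """Dispatch on the pool's layout instead of assuming `head_num` exists."""
    if hasattr(device_pool, "head_num"):
        per_token_width = device_pool.head_num * device_pool.head_dim
        return ("mha_host_pool", per_token_width)
    if hasattr(device_pool, "kv_lora_rank"):
        return ("mla_host_pool", device_pool.kv_lora_rank)
    raise TypeError(f"unsupported device pool: {type(device_pool).__name__}")
```

This is only an illustration of the shape mismatch; the actual fix for MLA support landed later (see #4009 below in this thread).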

@lambert0312 (Contributor) commented Feb 25, 2025

> Thank you @lambert0312 for pointing out, yes, this feature is still under meta stage and currently only supported MHA and GQA style memory pool. […]

Thanks @xiezhq-hermann

@zhaochenyang20 mentioned this pull request Mar 3, 2025
@zhyncs mentioned this pull request Mar 4, 2025
@xiezhq-hermann (Collaborator, Author) commented Mar 4, 2025

> Thanks @xiezhq-hermann

@lambert0312 just FYI, there is a PR from the community supporting MLA with hierarchical caching. It will be merged soon, but feel free to check it out already: #4009

@lambert0312 (Contributor)

> @lambert0312 just FYI, there is a PR from the community supporting MLA with hierarchical caching […]: #4009

@xiezhq-hermann Thanks, but I've run into a problem. I just experimented with #4009 and found that there is indeed a concurrency issue when TP>1: the program enters a deadlocked state. Please follow up, thank you!

@shensimeteor

> After code cleaning and basic performance benchmark, this PR is ready to merge. You can add --enable-hierarchical-cache option when starting a SGLang server to turn on this feature. This feature will still be under active development in the future months, your feedback will be greatly welcomed : ) Following is a throughput vs. median TTFT curve that demonstrates the benefit of hierarchical caching using a synthetic multi-turn benchmark, and you can reproduce it with Qwen/Qwen2.5-14B-Instruct on an A100-80G GPU as explained here:

> [figure: throughput vs. median TTFT curve]

Besides --enable-hierarchical-cache, do we also need to set cpu_offload_gb?

@xiezhq-hermann (Collaborator, Author)

> Besides --enable-hierarchical-cache, do we also need to set cpu_offload_gb?

Right now it allocates a host memory pool that is 4x the size of the device memory pool by default, so there is no need to set anything else, but more options will be added.

aoshen524 pushed a commit to aoshen524/sglang that referenced this pull request Mar 10, 2025
Co-authored-by: Wenxuan Tan <wenxuan.tan@wisc.edu>
Co-authored-by: Yineng Zhang <me@zhyncs.com>
@wangyibin-gh

Hi, I'm wondering when you are planning to support the L3 cache? I think it's reasonable to support pluggable L3 caches, which would encourage storage providers to implement their own L3 caches according to their product features. What you would need to do is define a set of KV cache APIs for getting/putting/evicting KV cache chunks/items and provide a demo implementation using something like a local SSD.

@msharmavikram

This is in the works, @wangyibin-gh!

@wangyibin-gh

> This is in the works @wangyibin-gh!

When do you expect this feature to be merged? And by the way, is there any documentation about it, especially w.r.t. the APIs?
