131 changes: 101 additions & 30 deletions python/sglang/srt/layers/attention/flashinfer_backend.py
@@ -54,7 +54,9 @@ class DecodeMetadata:

 @dataclass
 class PrefillMetadata:
-    prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper]
+    prefill_wrappers: List[
+        Union[BatchPrefillWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
+    ]
     use_ragged: bool
     extend_no_prefix: bool

@@ -160,16 +162,36 @@ def __init__(
         self.decode_wrappers = []
         for _ in range(self.num_wrappers):
             if not skip_prefill:
-                self.prefill_wrappers_paged.append(
-                    BatchPrefillWithPagedKVCacheWrapper(
-                        self.workspace_buffer,
-                        "NHD",
-                        backend="fa2",
-                    )
-                )
-                self.prefill_wrappers_verify.append(
-                    BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
-                )
+                if (
+                    self.enable_flashinfer_mla
+                    and not global_server_args_dict["disable_radix_cache"]
+                ):
+                    # use mla paged prefill
+                    self.prefill_wrappers_paged.append(
+                        BatchMLAPagedAttentionWrapper(
+                            self.workspace_buffer,
+                            backend="fa2",
+                        )
+                    )
+                    self.prefill_wrappers_verify.append(
+                        BatchMLAPagedAttentionWrapper(
+                            self.workspace_buffer,
+                            backend="fa2",
+                        )
+                    )
+                else:
+                    self.prefill_wrappers_paged.append(
+                        BatchPrefillWithPagedKVCacheWrapper(
+                            self.workspace_buffer,
+                            "NHD",
+                            backend="fa2",
+                        )
+                    )
+                    self.prefill_wrappers_verify.append(
+                        BatchPrefillWithPagedKVCacheWrapper(
+                            self.workspace_buffer, "NHD"
+                        )
+                    )
             if self.enable_flashinfer_mla:
                 self.decode_wrappers.append(
                     BatchMLAPagedAttentionWrapper(self.workspace_buffer, backend="fa2")
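As a reading aid for the hunk above, here is a minimal, standalone sketch of the wrapper-selection rule: an MLA paged wrapper is built only when FlashInfer MLA is enabled and the radix cache is still active, otherwise the standard paged wrapper is kept. The stub classes and the build_prefill_wrappers helper are illustrative stand-ins, not sglang or flashinfer code.

from typing import List, Union


class BatchPrefillWithPagedKVCacheWrapper:  # stand-in for the flashinfer class
    def __init__(self, workspace, layout="NHD", backend="fa2"):
        self.kind = "standard-paged"


class BatchMLAPagedAttentionWrapper:  # stand-in for the flashinfer class
    def __init__(self, workspace, backend="fa2"):
        self.kind = "mla-paged"


def build_prefill_wrappers(
    workspace,
    num_wrappers: int,
    enable_flashinfer_mla: bool,
    disable_radix_cache: bool,
) -> List[Union[BatchPrefillWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]]:
    wrappers = []
    for _ in range(num_wrappers):
        if enable_flashinfer_mla and not disable_radix_cache:
            # MLA paged prefill: the cached prefix is read from the paged KV cache.
            wrappers.append(BatchMLAPagedAttentionWrapper(workspace, backend="fa2"))
        else:
            # Ragged prefill (or non-MLA models): the standard paged wrapper suffices.
            wrappers.append(
                BatchPrefillWithPagedKVCacheWrapper(workspace, "NHD", backend="fa2")
            )
    return wrappers


print([w.kind for w in build_prefill_wrappers(None, 2, True, False)])
# ['mla-paged', 'mla-paged']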
@@ -237,7 +259,10 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
         else:
             prefix_lens = forward_batch.extend_prefix_lens

-            if self.is_multimodal:
+            if self.is_multimodal or (
+                self.enable_flashinfer_mla
+                and not global_server_args_dict["disable_radix_cache"]
+            ):
                 use_ragged = False
                 extend_no_prefix = False
             else:
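How I read the new condition (a sketch, not an authoritative statement of the design): ragged prefill runs over packed q/k/v and never touches the paged KV cache, so it has to be switched off both for multimodal models and for the MLA paged-prefill path, which must fetch the radix-cached prefix from the paged cache. The helper below is hypothetical and only mirrors the boolean logic.

def should_use_ragged(
    is_multimodal: bool,
    enable_flashinfer_mla: bool,
    disable_radix_cache: bool,
) -> bool:
    # Mirrors the condition in init_forward_metadata above.
    if is_multimodal or (enable_flashinfer_mla and not disable_radix_cache):
        return False
    return True


assert should_use_ragged(False, True, False) is False  # MLA paged prefill path
assert should_use_ragged(False, True, True) is True    # MLA ragged prefill path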
@@ -419,23 +444,43 @@ def forward_extend(

             logits_soft_cap = layer.logit_cap

-            o1, _ = self.prefill_wrapper_ragged.forward_return_lse(
-                q.view(-1, layer.tp_q_head_num, layer.head_dim),
-                k.view(-1, layer.tp_k_head_num, layer.head_dim),
-                v.view(-1, layer.tp_v_head_num, layer.v_head_dim),
-                causal=True,
-                sm_scale=layer.scaling,
-                logits_soft_cap=logits_soft_cap,
-            )
-
-            o = o1
+            if global_server_args_dict["disable_radix_cache"]:
+                # use mla ragged prefill
+                o, _ = self.prefill_wrapper_ragged.forward_return_lse(
+                    q.view(-1, layer.tp_q_head_num, layer.head_dim),
+                    k.view(-1, layer.tp_k_head_num, layer.head_dim),
+                    v.view(-1, layer.tp_v_head_num, layer.v_head_dim),
+                    causal=True,
+                    sm_scale=layer.scaling,
+                    logits_soft_cap=logits_soft_cap,
+                )

-            if save_kv_cache:
-                forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer,
-                    cache_loc,
-                    k,
-                    v,
-                )
+                if save_kv_cache:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer,
+                        cache_loc,
+                        k,
+                        v,
+                    )
+            else:
+                # use mla paged prefill
+                prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
+                    self._get_wrapper_idx(layer)
+                ]
+                if k is not None:
+                    assert v is not None
+                    if save_kv_cache:
+                        forward_batch.token_to_kv_pool.set_kv_buffer(
+                            layer, cache_loc, k, v
+                        )
+                qall = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+                k_buf = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+
+                o = prefill_wrapper_paged.run(
+                    qall[:, :, : layer.v_head_dim],
+                    qall[:, :, layer.v_head_dim :],
+                    k_buf[:, :, : layer.v_head_dim],
+                    k_buf[:, :, layer.v_head_dim :],
+                )

             return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
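The paged branch above splits both the query and the cached keys at layer.v_head_dim. A rough shape sketch, assuming DeepSeek-V2/V3-style MLA dimensions (kv_lora_rank = 512, qk_rope_head_dim = 64, so head_dim = 576 and v_head_dim = 512); tensor names here are illustrative, not the sglang ones:

import torch

num_tokens, num_heads = 8, 16
kv_lora_rank, rope_dim = 512, 64
head_dim = kv_lora_rank + rope_dim  # 576, corresponds to layer.head_dim
v_head_dim = kv_lora_rank           # 512, corresponds to layer.v_head_dim

qall = torch.randn(num_tokens, num_heads, head_dim)
k_buf = torch.randn(1024, 1, head_dim)  # paged cache: one latent "head" per token

q_nope = qall[:, :, :v_head_dim]   # absorbed query part, matches the compressed KV
q_rope = qall[:, :, v_head_dim:]   # RoPE part of the query
ckv = k_buf[:, :, :v_head_dim]     # compressed KV entries in the cache
k_pe = k_buf[:, :, v_head_dim:]    # cached RoPE keys

print(q_nope.shape, q_rope.shape, ckv.shape, k_pe.shape)
# torch.Size([8, 16, 512]) torch.Size([8, 16, 64]) torch.Size([1024, 1, 512]) torch.Size([1024, 1, 64])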
@@ -800,7 +845,9 @@ def update(
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
-        prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
+        prefill_wrappers: List[
+            Union[BatchPrefillWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
+        ],
         use_ragged: bool,
         encoder_lens: Optional[torch.Tensor],
         spec_info: Optional[SpecInfo],
@@ -814,7 +861,9 @@ def update_single_wrapper(
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
-        prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
+        prefill_wrappers: List[
+            Union[BatchPrefillWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
+        ],
         use_ragged: bool,
         encoder_lens: Optional[torch.Tensor],
         spec_info: Optional[SpecInfo],
@@ -923,7 +972,9 @@ def update_cross_attention(
     def call_begin_forward(
         self,
         wrapper_ragged: BatchPrefillWithRaggedKVCacheWrapper,
-        wrapper_paged: BatchPrefillWithPagedKVCacheWrapper,
+        wrapper_paged: Union[
+            BatchPrefillWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper
+        ],
         req_pool_indices: torch.Tensor,
         paged_kernel_lens: torch.Tensor,
         paged_kernel_lens_sum: int,
@@ -1004,6 +1055,26 @@ def call_begin_forward(
                 custom_mask=custom_mask,
                 non_blocking=True,
             )
+        elif (
+            global_config.enable_flashinfer_mla
+            and not global_server_args_dict["disable_radix_cache"]
+        ):
+            # mla paged prefill
+            kv_len_arr = kv_indptr[1:] - kv_indptr[:-1]
+            wrapper_paged.plan(
+                qo_indptr,
+                kv_indptr,
+                kv_indices,
+                kv_len_arr,
+                self.num_qo_heads,
+                512,
+                64,
+                1,
+                True,
+                1 / math.sqrt(192),
+                self.data_type,
+                self.data_type,
+            )


 class FlashInferMultiStepDraftBackend:

Review thread on the new MLA paged prefill branch:

Reviewer: @zhyncs A quick question: if I want to make MTP work with the flashinfer backend, is it this code block that has to run during the target verify stage, since the prefill incremental computation has to use the absorb trick with flashinfer?

Author (member): The full paged version is a temporary solution. I will soon support ragged prefill + paged prefill + paged decoding. These days have been quite busy, so I haven't updated.
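For readers puzzling over the bare constants in wrapper_paged.plan(...): the mapping below is my reading of the positional arguments, assuming flashinfer's BatchMLAPagedAttentionWrapper.plan argument order and DeepSeek-style MLA dimensions (kv_lora_rank = 512, qk_rope_head_dim = 64, qk_nope_head_dim = 128). It is annotation only and does not launch a kernel.

import math

plan_args = {
    "qo_indptr": "per-request query offsets",
    "kv_indptr": "per-request KV offsets",
    "kv_indices": "flattened KV page table",
    "kv_len_arr": "kv_indptr[1:] - kv_indptr[:-1], KV length per request",
    "num_heads": "self.num_qo_heads",
    "head_dim_ckv": 512,             # compressed-KV width (kv_lora_rank)
    "head_dim_kpe": 64,              # RoPE key width (qk_rope_head_dim)
    "page_size": 1,                  # one token per page in this code path
    "causal": True,
    "sm_scale": 1 / math.sqrt(192),  # 1/sqrt(qk_nope_head_dim + qk_rope_head_dim)
    "q_data_type": "self.data_type",
    "kv_data_type": "self.data_type",
}

for name, value in plan_args.items():
    print(f"{name:>12}: {value}")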
1 change: 1 addition & 0 deletions python/sglang/srt/managers/schedule_batch.py
@@ -66,6 +66,7 @@
     "enable_ep_moe": ServerArgs.enable_ep_moe,
     "device": ServerArgs.device,
     "enable_flashinfer_mla": ServerArgs.enable_flashinfer_mla,
+    "disable_radix_cache": ServerArgs.disable_radix_cache,
 }

 logger = logging.getLogger(__name__)
1 change: 1 addition & 0 deletions python/sglang/srt/model_executor/model_runner.py
@@ -177,6 +177,7 @@ def __init__(
                 "enable_ep_moe": server_args.enable_ep_moe,
                 "device": server_args.device,
                 "enable_flashinfer_mla": server_args.enable_flashinfer_mla,
+                "disable_radix_cache": server_args.disable_radix_cache,
             }
         )

7 changes: 5 additions & 2 deletions python/sglang/srt/models/deepseek_v2.py
@@ -511,8 +511,11 @@ def forward(
         forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         if global_server_args_dict["enable_flashinfer_mla"]:
-            if forward_batch.forward_mode.is_extend():
-                return self.forward_normal(positions, hidden_states, forward_batch)
+            if global_server_args_dict["disable_radix_cache"]:
+                if forward_batch.forward_mode.is_extend():
+                    return self.forward_normal(positions, hidden_states, forward_batch)
+                else:
+                    return self.forward_absorb(positions, hidden_states, forward_batch)
             else:
                 return self.forward_absorb(positions, hidden_states, forward_batch)
         else:
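A standalone sketch (not the model code) of the dispatch rule this hunk introduces, as I read it: with the radix cache enabled, an extend batch may reuse a cached KV prefix, so even extend must take the absorbed, paged MLA path; only when the radix cache is disabled is the plain MHA prefill path (forward_normal) safe for extend. forward_normal and forward_absorb are the real method names from deepseek_v2.py; everything else is illustrative.

def choose_attention_path(
    enable_flashinfer_mla: bool,
    disable_radix_cache: bool,
    is_extend: bool,
) -> str:
    if not enable_flashinfer_mla:
        return "default backend path"
    if disable_radix_cache and is_extend:
        return "forward_normal"  # ragged MHA prefill, no cached prefix to read
    return "forward_absorb"      # paged MLA, can read the prefix KV from the cache


for flags in [(True, True, True), (True, False, True), (True, True, False)]:
    print(flags, "->", choose_attention_path(*flags))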