
[Kernel] Optimize triton decoding kernels for long context #2271

@merrymercy

Description


We noticed that the current Triton decoding kernel is very slow on long contexts. This is due to a missing flash-decoding-style optimization: the kernel does not split attention along the KV sequence, so a single decoding request cannot keep the GPU busy as the context grows.

Reproduce

We test the decoding speed at context lengths of roughly 200 and 2,000 tokens.

Triton backend: the decoding speed drops from 147.64 token/s to 126.41 token/s.

$ python3 -m sglang.bench_offline_throughput --model meta-llama/Llama-3.1-8B-Instruct --dataset-name random --num-prompt 1 --random-input 128 --random-output 2048 --random-range 1 --attention-backend triton

[2024-11-30 05:10:04 TP0] Decode batch. #running-req: 1, #token: 234, token usage: 0.00, gen throughput (token/s): 147.64, #queue-req: 0
... 
[2024-11-30 05:10:18 TP0] Decode batch. #running-req: 1, #token: 2154, token usage: 0.00, gen throughput (token/s): 126.41, #queue-req: 0

FlashInfer backend: the decoding speed only drops from 144.17 token/s to 143.35 token/s.

$ python3 -m sglang.bench_offline_throughput --model meta-llama/Llama-3.1-8B-Instruct --dataset-name random --num-prompt 1 --random-input 128 --random-output 2048 --random-range 1

[2024-11-30 05:11:40 TP0] Decode batch. #running-req: 1, #token: 234, token usage: 0.00, gen throughput (token/s): 144.17, #queue-req: 0
...
[2024-11-30 05:11:54 TP0] Decode batch. #running-req: 1, #token: 2154, token usage: 0.00, gen throughput (token/s): 143.35, #queue-req: 0

Possible solutions

We can learn from lightllm's flash-decoding Triton kernel and improve the current Triton decoding kernel accordingly. Related links:
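To illustrate the idea behind the proposed fix, here is a minimal NumPy sketch (not the actual Triton kernel, and all function names are hypothetical): flash decoding partitions the KV sequence, computes a partial softmax-attention result per partition, and merges the partials with a log-sum-exp rescaling. In a real kernel each partition runs as an independent program, so even a batch of one request can use many SMs on a long context.

```python
import numpy as np

def flash_decode_attention(q, K, V, num_splits=4):
    """Single-query attention computed flash-decoding style.

    The KV sequence is split along its length; each split produces a
    partial (running max, sum of exps, weighted-V accumulator) triple,
    and the partials are merged with a log-sum-exp rescaling.
    """
    d = q.shape[-1]
    scale = 1.0 / np.sqrt(d)
    splits = np.array_split(np.arange(K.shape[0]), num_splits)

    maxes, sums, accs = [], [], []
    for idx in splits:
        s = (K[idx] @ q) * scale       # partial attention scores
        m = s.max()                    # partial running max
        p = np.exp(s - m)              # stabilized softmax numerator
        maxes.append(m)
        sums.append(p.sum())
        accs.append(p @ V[idx])        # partial weighted sum of values

    # Combine step: rescale every partial result to the global max.
    g = max(maxes)
    alphas = [np.exp(m - g) for m in maxes]
    denom = sum(a * s for a, s in zip(alphas, sums))
    numer = sum(a * acc for a, acc in zip(alphas, accs))
    return numer / denom

def reference_attention(q, K, V):
    """Plain softmax attention over the whole sequence, for comparison."""
    s = (K @ q) / np.sqrt(q.shape[-1])
    p = np.exp(s - s.max())
    return (p / p.sum()) @ V
```

The combine step is cheap (one small reduction over `num_splits` partials), so the split cost is amortized while the expensive score/value pass becomes parallel across the context length.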
