
Conversation

@Alcanderian (Collaborator) commented Jun 6, 2025

Motivation

Previously, num_head != 128 required an inelegant workaround, so the kernel has to be rewritten to support it. This PR:

  1. Adds a benchmark for cutlass_mla_decode
  2. Extends the supported num_head range
  3. Improves performance by up to 1.8x for page_size == 128
  4. Fixes split_kv, building on the upstream bugfix for Blackwell MLA split-k (flashinfer-ai/flashinfer#1109); some issues remain, ref: https://github.com/sgl-project/sglang/pull/6929/files#diff-3aef9e714fbacce71c7a8d812964d83355a6fa2953340ebc3f2087e9f0920d6eR215

Limitations:

  1. Automatic dynamic split_kv is not compatible with CUDA graph
  2. Static split_kv can hang with specific batch_size and seq_len combinations

TODO (further work):
Try a per-batch split-kv schedule

The tables below show throughput in GB/s for each case.

  • This PR
block_size=1, num_kv_splits=-1: 
cutlass mla:
    batch_size  seq_len    128 heads     64 heads     32 heads     16 heads
0          1.0      1.0    18.775741     9.323621     7.066667     5.914849
1          1.0     64.0    18.037941     8.492891     6.369953     5.546251
2          1.0    128.0    18.086956     9.333334     7.302476     5.933333
3          1.0    256.0    19.248121    11.767562     9.880312     8.944444
4          1.0    512.0    30.251950    20.342857    18.383585    17.413024
5          1.0   1024.0    47.565761    35.746746    33.859496    32.944443
6          1.0   2048.0    88.256958    65.945943    65.831741    63.242601
7          1.0   4096.0   152.649073   124.944850   123.052637   125.459454
8          1.0   8192.0   262.871008   227.231584   225.579344   225.437931
9          8.0      1.0   144.695649    56.619274    44.741961    37.504525
10         8.0     64.0   144.892515    58.609976    44.485247    36.867314
11         8.0    128.0   144.695649    57.853107    45.039005    37.815768
12         8.0    256.0   149.645091    78.089209    65.518332    59.431869
13         8.0    512.0   226.133339   130.754661   119.121392   112.588323
14         8.0   1024.0   377.765798   232.612351   221.660749   215.636362
15         8.0   2048.0   432.146795   325.333342   316.513943   311.895788
16         8.0   4096.0   709.012474   542.171414   534.399986   530.751216
17         8.0   8192.0  1036.674388   852.300472   852.170655   848.441399
18        32.0      1.0   577.214113   231.599354   176.064888   147.588667
19        32.0     64.0   577.214113   234.631753   178.086961   152.210434
20        32.0    128.0   577.997268   238.139541   175.779761   147.948058
21        32.0    256.0   510.632211   282.115667   237.178928   214.806508
22        32.0    512.0   545.104826   362.009936   326.463364   309.074712
23        32.0   1024.0   787.352045   571.204879   541.986964   527.111092
24        32.0   2048.0  1084.185824   877.303361   844.257206   838.568101
25        32.0   4096.0  1404.869242  1214.767707  1205.797974  1197.332002
26        32.0   8192.0  1676.857473  1534.933217  1538.568754  1532.934185
27        64.0      1.0  1066.292903   408.142352   352.415598   299.789487
28        64.0     64.0  1066.292903   409.600015   357.052647   306.339501
29        64.0    128.0  1066.292903   408.142352   365.160647   299.789487
30        64.0    256.0   663.703703   377.216323   339.742771   303.058818
31        64.0    512.0   860.180284   556.555741   541.203134   513.262140
32        64.0   1024.0  1113.536499   831.596485   821.894710   807.489338
33        64.0   2048.0  1406.839498  1147.971553  1149.788455  1133.307507
34        64.0   4096.0  1678.575700  1488.085743  1502.659376  1491.733281
35        64.0   8192.0  1909.722271  1781.333347  1795.517454  1788.688455
36       128.0      1.0  1437.920745   586.939618   503.100815   455.111118
37       128.0     64.0  1437.920745   587.540995   503.100815   455.111118
38       128.0    128.0  1436.708323   587.240152   503.684463   454.543644
39       128.0    256.0  1591.783451   774.277279   717.959591   699.688051
40       128.0    512.0  1721.213083  1011.919480  1007.187507  1014.849762
41       128.0   1024.0  1874.864710  1319.901948  1347.295742  1368.024770
42       128.0   2048.0  1998.351455  1626.401955  1658.537412  1670.407781
43       128.0   4096.0  2110.493046  1875.710591  1900.277381  1911.494105
44       128.0   8192.0  2177.068320  2043.084483  2060.205067  2066.543098
45       256.0      1.0  1743.157046   695.922340   623.368264   570.268303
46       256.0     64.0  1710.779087   695.922340   616.946356   570.714700
47       256.0    128.0  1740.486203   696.344883   615.416002   570.268303
48       256.0    256.0  1816.120414   898.915094   867.961912   859.786158
49       256.0    512.0  1921.664149  1163.980050  1186.472062  1200.942143
50       256.0   1024.0  2036.203147  1465.864991  1509.020103  1533.656186
51       256.0   2048.0  2120.437226  1747.550269  1786.140684  1803.644102
52       256.0   4096.0  2173.369479  1955.855464  1982.444864  1994.506624
53       256.0   8192.0  2207.865421  2087.281599  2102.928874  2110.458164
block_size=32, num_kv_splits=-1: 
cutlass mla:
    batch_size  seq_len    128 heads     64 heads     32 heads     16 heads
0          1.0      1.0    18.062415     9.572649     7.310345     6.085470
1          1.0     64.0    18.062415     9.343065     7.310345     5.951933
2          1.0    128.0    18.062415     9.343065     7.294624     6.111588
3          1.0    256.0    19.844961    11.891323     9.629755     9.191793
4          1.0    512.0    31.190805    20.361036    18.416443    17.444146
5          1.0   1024.0    48.788008    35.746746    33.859496    32.915870
6          1.0   2048.0    88.256958    67.777775    65.888887    64.944442
7          1.0   4096.0   152.798437   124.842111   123.052637   122.258431
8          1.0   8192.0   263.098774   228.271554   225.408190   225.867995
9          8.0      1.0   144.695649    58.778189    44.705105    37.504525
10         8.0     64.0   144.892515    58.276420    43.803067    37.107491
11         8.0    128.0   144.695649    58.371334    43.803067    37.442894
12         8.0    256.0   149.958158    78.941091    66.280728    59.950546
13         8.0    512.0   231.684104   132.368917   119.380156   113.405954
14         8.0   1024.0   377.765798   234.348258   222.450146   218.743520
15         8.0   2048.0   431.016991   326.012520   316.596446   311.895788
16         8.0   4096.0   708.610309   542.655932   535.116665   530.751216
17         8.0   8192.0  1040.226086   860.881701   863.627907   849.506715
18        32.0      1.0   577.214113   231.786582   175.212267   149.648609
19        32.0     64.0   577.997268   233.866227   178.233175   150.762613
20        32.0    128.0   577.997268   232.350087   175.212267   147.112187
21        32.0    256.0   511.087323   283.589808   237.178928   214.527008
22        32.0    512.0   543.399268   366.375876   322.314764   309.380881
23        32.0   1024.0   786.502684   580.507024   541.986964   526.882409
24        32.0   2048.0  1085.971136   877.303361   862.239253   850.181812
25        32.0   4096.0  1405.264315  1216.593067  1205.797974  1197.935500
26        32.0   8192.0  1676.278843  1542.846464  1537.822054  1532.438166
27        64.0      1.0  1066.292903   410.626555   349.859785   293.750197
28        64.0     64.0  1066.292903   416.744186   347.897448   293.277546
29        64.0    128.0  1066.292903   408.433054   348.455843   291.635207
30        64.0    256.0   662.553427   377.052528   339.584529   303.058818
31        64.0    512.0   860.606541   555.919187   541.425305   513.262140
32        64.0   1024.0  1121.673843   832.121156   830.362220   807.220978
33        64.0   2048.0  1396.782673  1148.235274  1148.700919  1133.575749
34        64.0   4096.0  1670.160447  1488.541725  1497.021718  1485.440038
35        64.0   8192.0  1908.221937  1775.965089  1789.518287  1783.300448
36       128.0      1.0  1439.135072   587.540995   502.809498   455.964992
37       128.0     64.0  1436.708323   587.540995   503.100815   448.117995
38       128.0    128.0  1437.920745   586.639391   503.100815   454.827204
39       128.0    256.0  1591.783451   784.773580   714.792180   700.059438
40       128.0    512.0  1718.658071  1013.678130  1006.803082  1014.849762
41       128.0   1024.0  1864.079213  1319.571728  1347.659093  1357.701508
42       128.0   2048.0  1997.973137  1618.762558  1656.840397  1669.825246
43       128.0   4096.0  2103.165044  1868.675734  1894.450615  1904.219549
44       128.0   8192.0  2172.079669  2042.757749  2056.695144  2065.866858
45       256.0      1.0  1740.486203   695.711261   611.515499   570.714700
46       256.0     64.0  1724.195368   695.922340   611.300251   571.161745
47       256.0    128.0  1695.458759   695.922340   623.368264   572.507248
48       256.0    256.0  1814.683599   897.753450   867.961912   859.506004
49       256.0    512.0  1920.070717  1162.124772  1186.338628  1190.357360
50       256.0   1024.0  2035.137519  1461.803267  1508.792237  1525.708528
51       256.0   2048.0  2113.217676  1744.195505  1780.575755  1796.875520
52       256.0   4096.0  2172.188713  1952.514402  1978.298493  1990.330927
53       256.0   8192.0  2209.937149  2087.708246  2104.733073  2112.105217
block_size=64, num_kv_splits=-1: 
cutlass mla:
    batch_size  seq_len    128 heads     64 heads     32 heads     16 heads
0          1.0      1.0    18.037941     9.431579     7.066667     5.939520
1          1.0     64.0    18.062415     9.333334     7.066667     5.951933
2          1.0    128.0    18.062415     9.343065     7.066667     5.939520
3          1.0    256.0    19.866962    12.039042     9.820690     8.960000
4          1.0    512.0    30.353468    20.342857    18.416443    17.859102
5          1.0   1024.0    49.103447    35.777777    33.888888    32.944443
6          1.0   2048.0    88.351554    67.660309    65.888887    64.888115
7          1.0   4096.0   156.472941   124.944850   123.306135   122.258431
8          1.0   8192.0   262.416603   228.097565   227.305269   225.954205
9          8.0      1.0   145.089918    58.754099    44.815854    37.690653
10         8.0     64.0   145.089918    58.850577    44.412439    37.412152
11         8.0    128.0   144.695649    58.371334    44.087736    36.571427
12         8.0    256.0   149.489047    78.826434    66.425657    59.906977
13         8.0    512.0   226.369128   131.842310   119.813953   113.488372
14         8.0   1024.0   368.971676   236.703515   224.206746   218.586036
15         8.0   2048.0   437.305052   329.451477   320.439051   315.679325
16         8.0   4096.0   709.415096   548.786241   543.375394   546.117637
17         8.0   8192.0  1054.222185   871.854348   871.864992   869.309832
18        32.0      1.0   577.997268   235.983528   177.214692   148.672102
19        32.0     64.0   577.997268   237.744613   179.411568   151.640605
20        32.0    128.0   577.997268   235.983528   179.263415   149.403282
21        32.0    256.0   511.543248   282.850806   237.333325   215.790577
22        32.0    512.0   552.386776   361.650797   326.301827   309.380881
23        32.0   1024.0   796.817502   580.251630   543.401469   533.127277
24        32.0   2048.0  1097.264599   876.995425   862.851861   851.088518
25        32.0   4096.0  1429.382119  1232.958313  1215.902432  1206.755027
26        32.0   8192.0  1695.290845  1550.590675  1554.671405  1548.977865
27        64.0      1.0  1081.177658   426.349443   357.346486   299.789487
28        64.0     64.0  1149.754392   419.181269   357.052647   298.806563
29        64.0    128.0  1074.360676   417.047275   353.851662   299.297218
30        64.0    256.0   673.842523   381.693171   339.584529   306.581114
31        64.0    512.0   861.033221   563.002332   535.272748   519.235261
32        64.0   1024.0  1123.402167   848.447753   838.442923   814.529876
33        64.0   2048.0  1408.341741  1156.205501  1158.287131  1151.845981
34        64.0   4096.0  1687.646087  1495.184953  1503.839182  1500.444488
35        64.0   8192.0  1926.955890  1796.622628  1805.834499  1799.563057
36       128.0      1.0  1437.920745   587.240152   504.562478   456.250329
37       128.0     64.0  1437.920745   586.639391   504.269467   455.395387
38       128.0    128.0  1437.920745   586.939618   503.100815   455.111118
39       128.0    256.0  1592.888865   776.354033   713.742551   699.688051
40       128.0    512.0  1722.066442  1005.638633  1009.500150  1014.437721
41       128.0   1024.0  1878.487665  1319.571728  1355.334997  1367.254697
42       128.0   2048.0  2021.708402  1635.183253  1668.792828  1680.373441
43       128.0   4096.0  2118.096768  1880.702978  1913.376163  1923.288632
44       128.0   8192.0  2195.889346  2060.337349  2075.610715  2085.252766
45       256.0      1.0  1740.486203   696.133547   611.085154   575.898904
46       256.0     64.0  1744.942167   695.922340   616.946356   576.126444
47       256.0    128.0  1738.710183   696.133547   610.870209   576.126444
48       256.0    256.0  1815.401722   897.985495   868.220313   866.849851
49       256.0    512.0  1922.727907  1167.708489  1186.338628  1202.097603
50       256.0   1024.0  2037.270058  1471.589377  1516.806773  1535.231387
51       256.0   2048.0  2133.730160  1759.549313  1796.544312  1814.924362
52       256.0   4096.0  2191.959576  1965.350941  1997.694447  2007.669278
53       256.0   8192.0  2230.292483  2107.956903  2123.508777  2132.732846
block_size=128, num_kv_splits=-1: 
cutlass mla:
    batch_size  seq_len    128 heads     64 heads     32 heads     16 heads
0          1.0      1.0    23.272728    10.112867     7.656885     6.400000
1          1.0     64.0    24.291972    10.033595     7.726652     6.480091
2          1.0    128.0    24.203636     9.933481     7.446762     6.371365
3          1.0    256.0    23.989291    13.581581    11.426279    10.082192
4          1.0    512.0    35.425587    21.886647    20.164384    19.422886
5          1.0   1024.0    56.535978    40.566927    38.577074    37.098730
6          1.0   2048.0    99.796610    75.916381    72.427485    72.425947
7          1.0   4096.0   178.672765   143.757575   137.529409   136.529409
8          1.0   8192.0   287.787678   252.543472   251.131013   250.211318
9          8.0      1.0   193.629092    60.797285    47.193042    39.384617
10         8.0     64.0   194.335773    60.745762    46.625429    39.487000
11         8.0    128.0   195.405498    61.369861    45.645081    37.597359
12         8.0    256.0   174.403899    79.753123    69.094767    62.401209
13         8.0    512.0   260.297360   134.319826   122.849481   117.501882
14         8.0   1024.0   419.981580   255.010064   237.505703   232.122333
15         8.0   2048.0   458.273797   354.104300   344.040783   339.301579
16         8.0   4096.0   829.535224   621.527131   612.931909   608.786897
17         8.0   8192.0  1407.258411  1081.839994  1103.903341  1072.463744
18        32.0      1.0   699.481139   244.433084   183.351343   156.322466
19        32.0     64.0   699.481139   248.457527   181.663595   153.945939
20        32.0    128.0   699.481139   248.242436   187.630084   159.887726
21        32.0    256.0   575.742962   301.511107   258.725334   238.587546
22        32.0    512.0   588.713198   373.891292   343.109267   330.322587
23        32.0   1024.0   927.592888   633.179060   611.792376   593.289696
24        32.0   2048.0  1442.222013  1098.751058  1067.660615  1053.282850
25        32.0   4096.0  2142.847284  1705.707846  1699.760011  1704.341311
26        32.0   8192.0  2945.047613  2514.823601  2518.361783  2526.540472
27        64.0      1.0  1267.809534   468.114280   374.936106   310.514487
28        64.0     64.0  1267.809534   466.211362   361.211328   306.855228
29        64.0    128.0  1267.809534   465.454564   357.640843   303.786652
30        64.0    256.0   715.458523   404.448989   350.523076   316.285845
31        64.0    512.0   985.643571   624.219153   597.062909   575.174952
32        64.0   1024.0  1445.885975  1016.698414  1037.177274  1009.949306
33        64.0   2048.0  2139.354491  1609.896966  1630.700181  1627.540421
34        64.0   4096.0  2812.110256  2353.612536  2393.513643  2395.871000
35        64.0   8192.0  3447.733099  3083.209370  3143.788385  3147.628106
36       128.0      1.0  1832.189286   652.377702   586.328163   530.245821
37       128.0     64.0  1836.137885   652.377702   588.713198   530.245821
38       128.0    128.0  1834.161518   652.377702   577.745828   530.245821
39       128.0    256.0  2047.999945   893.366247   847.530396   810.640418
40       128.0    512.0  2637.363672  1357.073947  1333.581444  1352.766665
41       128.0   1024.0  3139.237983  1941.718095  2007.681725  2078.671713
42       128.0   2048.0  3755.577898  2670.828371  2812.887095  2901.953955
43       128.0   4096.0  4132.412814  3180.265815  3253.974792  3299.831944
44       128.0   8192.0  3723.208330  3191.883251  3437.026621  3308.128904
45       256.0      1.0  2079.238530   799.219521   786.194638   736.080786
46       256.0     64.0  2366.577742   824.500361   776.701242   735.338389
47       256.0    128.0  2366.577742   824.204095   786.550707   736.080786
48       256.0    256.0  2603.586790  1105.828740  1110.145422  1100.928175
49       256.0    512.0  3149.055256  1613.919183  1683.359223  1753.375409
50       256.0   1024.0  3684.588880  2275.457334  2435.837305  2570.974397
51       256.0   2048.0  4047.294114  2762.747788  3075.079007  3054.688465
52       256.0   4096.0  4015.363449  3066.344299  3438.580946  3329.591555
53       256.0   8192.0  4124.473260  3279.465059  3357.222819  3492.303453
Benchmark finished!
  • Before
block_size=1: 
cutlass mla:
    batch_size  seq_len    128 heads
0          1.0      1.0    19.750742
1          1.0     64.0    18.802260
2          1.0    128.0    18.882269
3          1.0    256.0    22.288557
4          1.0    512.0    25.599999
5          1.0   1024.0    27.955829
6          1.0   2048.0    31.141671
7          1.0   4096.0    33.618944
8          1.0   8192.0    35.015109
9          8.0      1.0   153.011494
10         8.0     64.0   157.538456
11         8.0    128.0   152.354787
12         8.0    256.0   178.976282
13         8.0    512.0   204.799993
14         8.0   1024.0   229.706368
15         8.0   2048.0   254.321640
16         8.0   4096.0   271.405611
17         8.0   8192.0   281.467990
18        32.0      1.0   577.214113
19        32.0     64.0   577.214113
20        32.0    128.0   577.214113
21        32.0    256.0   663.703703
22        32.0    512.0   774.622632
23        32.0   1024.0   892.396581
24        32.0   2048.0   982.795846
25        32.0   4096.0  1047.394726
26        32.0   8192.0  1096.086588
27        64.0      1.0  1123.968324
28        64.0     64.0  1122.487459
29        64.0    128.0  1125.453102
30        64.0    256.0  1321.290363
31        64.0    512.0  1464.337329
32        64.0   1024.0  1655.137317
33        64.0   2048.0  1835.007996
34        64.0   4096.0  2001.249562
35        64.0   8192.0  2116.021380
36       128.0      1.0  1442.790820
37       128.0     64.0  1441.570176
38       128.0    128.0  1441.570176
39       128.0    256.0  1592.888865
40       128.0    512.0  1722.066442
41       128.0   1024.0  1877.278458
42       128.0   2048.0  2012.453897
43       128.0   4096.0  2116.751042
44       128.0   8192.0  2184.410635
45       256.0      1.0  1744.049150
46       256.0     64.0  1744.049150
47       256.0    128.0  1740.041858
48       256.0    256.0  1845.341863
49       256.0    512.0  1939.638683
50       256.0   1024.0  2036.558715
51       256.0   2048.0  2120.756951
52       256.0   4096.0  2177.572272
53       256.0   8192.0  2208.681114
block_size=32: 
cutlass mla:
    batch_size  seq_len    128 heads
0          1.0      1.0    18.160982
1          1.0     64.0    18.775741
2          1.0    128.0    18.136240
3          1.0    256.0    21.642512
4          1.0    512.0    25.149213
5          1.0   1024.0    27.887393
6          1.0   2048.0    31.059533
7          1.0   4096.0    33.503541
8          1.0   8192.0    34.998964
9          8.0      1.0   152.791963
10         8.0     64.0   157.538456
11         8.0    128.0   153.895960
12         8.0    256.0   178.530510
13         8.0    512.0   201.380345
14         8.0   1024.0   231.898222
15         8.0   2048.0   254.321640
16         8.0   4096.0   269.009476
17         8.0   8192.0   279.219224
18        32.0      1.0   577.214113
19        32.0     64.0   577.214113
20        32.0    128.0   577.214113
21        32.0    256.0   662.936409
22        32.0    512.0   774.622632
23        32.0   1024.0   891.850768
24        32.0   2048.0   980.968395
25        32.0   4096.0  1039.875160
26        32.0   8192.0  1085.918420
27        64.0      1.0  1149.754392
28        64.0     64.0  1125.453102
29        64.0    128.0  1146.659471
30        64.0    256.0  1292.987596
31        64.0    512.0  1464.337329
32        64.0   1024.0  1654.198491
33        64.0   2048.0  1811.070336
34        64.0   4096.0  1979.449385
35        64.0   8192.0  2088.951287
36       128.0      1.0  1439.135072
37       128.0     64.0  1437.920745
38       128.0    128.0  1437.920745
39       128.0    256.0  1590.679571
40       128.0    512.0  1719.508899
41       128.0   1024.0  1863.483656
42       128.0   2048.0  1996.083162
43       128.0   4096.0  2103.386207
44       128.0   8192.0  2151.576326
45       256.0      1.0  1724.631656
46       256.0     64.0  1725.504893
47       256.0    128.0  1727.254025
48       256.0    256.0  1815.401722
49       256.0    512.0  1922.727907
50       256.0   1024.0  2027.708744
51       256.0   2048.0  2113.429366
52       256.0   4096.0  2161.851591
53       256.0   8192.0  2174.814462
block_size=64: 
cutlass mla:
    batch_size  seq_len    128 heads
0          1.0      1.0    18.828854
1          1.0     64.0    19.181557
2          1.0    128.0    18.802260
3          1.0    256.0    21.668682
4          1.0    512.0    24.849816
5          1.0   1024.0    28.686182
6          1.0   2048.0    31.777950
7          1.0   4096.0    33.925701
8          1.0   8192.0    35.167198
9          8.0      1.0   157.538456
10         8.0     64.0   154.118663
11         8.0    128.0   157.771848
12         8.0    256.0   178.530510
13         8.0    512.0   204.993384
14         8.0   1024.0   232.193636
15         8.0   2048.0   254.419758
16         8.0   4096.0   270.934724
17         8.0   8192.0   280.088564
18        32.0      1.0   577.214113
19        32.0     64.0   577.214113
20        32.0    128.0   577.997268
21        32.0    256.0   667.566922
22        32.0    512.0   774.622632
23        32.0   1024.0   892.396581
24        32.0   2048.0   983.895589
25        32.0   4096.0  1047.833878
26        32.0   8192.0  1096.210288
27        64.0      1.0  1151.308126
28        64.0     64.0  1152.866065
29        64.0    128.0  1143.581167
30        64.0    256.0  1292.987596
31        64.0    512.0  1464.337329
32        64.0   1024.0  1655.137317
33        64.0   2048.0  1836.924802
34        64.0   4096.0  2003.656858
35        64.0   8192.0  2115.790922
36       128.0      1.0  1441.570176
37       128.0     64.0  1439.135072
38       128.0    128.0  1439.135072
39       128.0    256.0  1591.783451
40       128.0    512.0  1721.213083
41       128.0   1024.0  1877.278458
42       128.0   2048.0  2011.878347
43       128.0   4096.0  2117.872501
44       128.0   8192.0  2177.190314
45       256.0      1.0  1729.884177
46       256.0     64.0  1742.265855
47       256.0    128.0  1725.504893
48       256.0    256.0  1840.160530
49       256.0    512.0  1923.260228
50       256.0   1024.0  2036.203147
51       256.0   2048.0  2128.779523
52       256.0   4096.0  2183.161336
53       256.0   8192.0  2197.876366
block_size=128: 
cutlass mla:
    batch_size  seq_len    128 heads
0          1.0      1.0    18.882269
1          1.0     64.0    18.722925
2          1.0    128.0    18.235616
3          1.0    256.0    21.853658
4          1.0    512.0    25.056325
5          1.0   1024.0    28.713296
6          1.0   2048.0    31.777950
7          1.0   4096.0    34.096070
8          1.0   8192.0    35.444317
9          8.0      1.0   157.538456
10         8.0     64.0   157.771848
11         8.0    128.0   157.771848
12         8.0    256.0   178.530510
13         8.0    512.0   204.993384
14         8.0   1024.0   232.045835
15         8.0   2048.0   254.517953
16         8.0   4096.0   272.679249
17         8.0   8192.0   281.386482
18        32.0      1.0   577.997268
19        32.0     64.0   577.214113
20        32.0    128.0   577.997268
21        32.0    256.0   662.170888
22        32.0    512.0   774.622632
23        32.0   1024.0   892.943062
24        32.0   2048.0   991.290490
25        32.0   4096.0  1054.022356
26        32.0   8192.0  1096.272148
27        64.0      1.0  1126.941808
28        64.0     64.0  1146.659471
29        64.0    128.0  1125.453102
30        64.0    256.0  1316.739385
31        64.0    512.0  1464.337329
32        64.0   1024.0  1656.077209
33        64.0   2048.0  1849.158033
34        64.0   4096.0  2013.749813
35        64.0   8192.0  2130.288230
36       128.0      1.0  1444.013533
37       128.0     64.0  1437.920745
38       128.0    128.0  1441.570176
39       128.0    256.0  1591.783451
40       128.0    512.0  1722.920649
41       128.0   1024.0  1879.698289
42       128.0   2048.0  2034.573118
43       128.0   4096.0  2130.513706
44       128.0   8192.0  2188.100281
45       256.0      1.0  1742.265855
46       256.0     64.0  1744.049150
47       256.0    128.0  1721.147525
48       256.0    256.0  1861.062948
49       256.0    512.0  1954.647129
50       256.0   1024.0  2059.207038
51       256.0   2048.0  2135.025534
52       256.0   4096.0  2181.017459
53       256.0   8192.0  2206.173322
Benchmark finished!

Modifications

Checklist

@gemini-code-assist (bot, Contributor) left a comment

Hello @Alcanderian, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

Summary of Changes

Hello! Gemini-code-assist has reviewed this pull request. The main purpose of this PR is to enhance the cutlass_mla_decode kernel to properly support attention head counts (num_head) that are less than 128. Previously, this required workarounds. The changes involve modifying the underlying CUDA kernel to handle different page sizes more effectively, updating the Python wrapper to manage input tensors for smaller head counts by padding, adding a new benchmark script to measure performance across various configurations including different head counts, and extending the unit tests to ensure correctness for the newly supported range of head counts. According to the author's benchmarks, these changes result in performance improvements of up to 1.8x for cases where the page_size is 128.

Highlights

  • Extended Head Count Support: The cutlass_mla_decode kernel now supports a number of attention heads (num_head) less than or equal to 128, removing a previous limitation.
  • New Performance Benchmark: A new Triton-based benchmark script (bench_cutlass_mla.py) has been added to measure the performance (GB/s) of the cutlass_mla_decode kernel across various batch sizes, sequence lengths, block sizes, and importantly, different numbers of heads (128, 64, 32, 16).
  • Kernel Adaptation for Page Size: The core CUDA kernel (cutlass_mla_kernel.cu) has been updated to use a template parameter (IsPaged128) to conditionally adjust its behavior (specifically, the kIsCpAsync setting) based on whether the page size is 128. This allows for better optimization across different memory layouts.
  • Python Wrapper Input Handling: The Python interface (attention.py) now checks the number of heads (H) and pads the input query tensor to 128 heads if H < 128 before calling the CUDA kernel. The output is then sliced back to the original H heads (see the sketch after this list).
  • Expanded Unit Tests: Unit tests (test_cutlass_mla.py) have been updated to include parameterization for different numbers of heads (16, 32, 64, 128) to verify the correctness of the kernel across the extended support range.
  • Reported Performance Improvement: The author's benchmark results indicate performance gains, particularly for page_size == 128 configurations with varying head counts, showing up to a 1.8x improvement.
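
A minimal sketch of the padding path described in the "Python Wrapper Input Handling" highlight above. The function and parameter names here are illustrative assumptions, not the PR's actual code; the real wrapper invokes the sgl-kernel CUDA op directly, which is represented below by a caller-supplied `decode_128_heads` callable.

```python
import torch
from typing import Callable

# Head count the SM100 CUTLASS MLA kernel is written for (assumed constant name).
MAX_NUM_HEADS = 128


def pad_heads_and_decode(
    q_nope_and_q_pe: torch.Tensor,  # [B_q, H, D_latent + D_rope]
    decode_128_heads: Callable[[torch.Tensor], torch.Tensor],
) -> torch.Tensor:
    """Pad the query head dimension to 128, run the kernel, slice back to H heads.

    `decode_128_heads` stands in for the underlying CUDA op, which always
    consumes a [B_q, 128, D] query and returns a [B_q, 128, D_latent] output.
    """
    B_q, H, D = q_nope_and_q_pe.shape
    assert H <= MAX_NUM_HEADS, "kernel supports at most 128 query heads"

    if H < MAX_NUM_HEADS:
        # Zero-pad the missing heads so the kernel always sees 128 of them.
        q = q_nope_and_q_pe.new_zeros((B_q, MAX_NUM_HEADS, D))
        q[:, :H] = q_nope_and_q_pe
    else:
        q = q_nope_and_q_pe

    out = decode_128_heads(q)  # [B_q, 128, D_latent]

    # Drop the padded heads and return a dense tensor for downstream ops.
    return out[:, :H].contiguous() if H < MAX_NUM_HEADS else out
```

The trade-off here is redundant compute on the dummy heads in exchange for compatibility with the fixed-width kernel, which is why the kernel-side changes (IsPaged128 / kIsCpAsync) still matter for throughput.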

Changelog

  • sgl-kernel/benchmark/bench_cutlass_mla.py
    • Added a new Triton benchmark script for cutlass_mla_decode (Lines 1-115).
    • Configured the benchmark to test different numbers of heads (128, 64, 32, 16) (Lines 21-31).
    • Included logic to pad block_num for CUTLASS tile alignment based on block_size (Lines 56-59).
    • Calculates and reports performance in GB/s (Lines 83-91); see the GB/s sketch after this changelog.
  • sgl-kernel/csrc/attention/cutlass_mla_kernel.cu
    • Added IsPaged128 template parameter to MlaSm100 struct (Line 58).
    • Made kIsCpAsync conditional on !IsPaged128 in FmhaKernel definition (Line 86).
    • Added IsPaged128 template parameter to runMla function signature (Line 168).
    • Instantiated MlaSm100Type with the IsPaged128 parameter in runMla (Line 177).
    • Added logic in cutlass_mla_decode to check page_size and call runMla with the appropriate IsPaged128 value (Lines 198-215).
    • Fixed MlaSm100Type instantiation in cutlass_mla_get_workspace_size to use IsPaged128=true (Line 221).
  • sgl-kernel/python/sgl_kernel/attention.py
    • Changed the assertion for the number of heads (H) from == 128 to <= 128 (Line 76).
    • Added logic to pad the input query tensor (q_nope_and_q_pe) to 128 heads when H < 128 (Lines 77-80).
    • Changed the output tensor creation to always be size (B_q, 128, D_latent) (Line 104).
    • Sliced the output tensor back to the original number of heads (H) and made it contiguous before returning (Line 109).
  • sgl-kernel/tests/test_cutlass_mla.py
    • Added parameterization for num_heads with values [16, 32, 64, 128] to the test_cutlass_mla_decode function (Line 43).
    • Updated the test_cutlass_mla_decode function signature to accept the num_heads parameter (Lines 45-50).
    • Set the number of query heads (h_q) to the parameterized num_heads value (Line 57).
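
As a companion to the benchmark entry above ("Calculates and reports performance in GB/s"), here is one plausible way such a figure can be computed. This is an assumption, not necessarily the formula used in bench_cutlass_mla.py: it assumes the dominant traffic is reading the paged latent KV cache with a combined head dimension of 576 (512 latent + 64 rope) and 2-byte elements, and it ignores query reads and output writes.

```python
def mla_decode_gbps(batch_size: int, avg_seq_len: int, time_us: float,
                    head_dim: int = 576, dtype_bytes: int = 2) -> float:
    """Estimate decode throughput in GB/s from KV-cache bytes read per step."""
    # One decode step reads ~seq_len KV entries of `head_dim` elements per request.
    bytes_read = batch_size * avg_seq_len * head_dim * dtype_bytes
    return bytes_read / (time_us * 1e-6) / 1e9


# Example: batch 128, seq_len 8192 completing in ~1.1 ms works out to ~1.1 TB/s.
print(f"{mla_decode_gbps(128, 8192, 1100.0):.0f} GB/s")  # ~1098 GB/s
```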

@gemini-code-assist (bot, Contributor) left a comment

Code Review

This pull request introduces valuable enhancements to cutlass_mla_decode by enabling support for num_head < 128 through Python-level padding and by optimizing performance. The benchmark results are promising, and the new benchmark script is a welcome addition.

I have a few questions and suggestions to improve the robustness and clarity of these changes.

Summary of Findings

  • Kernel Internals (CUDA): The rationale for kIsCpAsync = !IsPaged128 (synchronous copy for page_size=128, async otherwise) needs clarification for better understanding of the performance optimization.
  • Workspace Calculation (CUDA): It's important to confirm that hardcoding IsPaged128 = true in cutlass_mla_get_workspace_size is safe and doesn't lead to incorrect workspace allocation if kIsCpAsync (which depends on IsPaged128) affects workspace requirements.
  • Code Maintainability (Benchmark Script): The benchmark script uses magic numbers for dimensions (d, dv) and a fragile if/elif chain for determining head counts, which could be improved for robustness and readability.
  • Code Clarity (Python Wrapper): The magic number 128 for padding head dimensions in attention.py should ideally be a named constant to improve maintainability.
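
To make the two maintainability suggestions above concrete, here is a small sketch of what named constants could look like. The names are hypothetical illustrations, not code from this PR or a requested spelling.

```python
# Hypothetical named constants replacing the magic numbers the review points at.
CUTLASS_MLA_MAX_NUM_HEADS = 128                 # instead of a bare 128 in attention.py
KV_LORA_RANK = 512                              # dv in the benchmark / tests
QK_ROPE_HEAD_DIM = 64                           # d - dv
QK_HEAD_DIM = KV_LORA_RANK + QK_ROPE_HEAD_DIM   # d = 576

# Head counts to benchmark, driving line labels without an if/elif chain.
BENCH_NUM_HEADS = (128, 64, 32, 16)
LINE_LABELS = {f"{h} heads": h for h in BENCH_NUM_HEADS}
```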

Merge Readiness

The pull request introduces significant improvements. However, due to the high-severity questions regarding the CUDA kernel's workspace size calculation and the rationale behind the kIsCpAsync logic, I recommend addressing these points before merging. The medium-severity suggestions for maintainability would also be good to consider. As a reviewer, I am not authorized to approve the pull request; please ensure further review and approval from other maintainers after addressing the feedback.

@Alcanderian Alcanderian changed the title [perf][sgl-kernel] extend cutlass_mla_decode to support num_head < 128 [WIP][perf][sgl-kernel] extend cutlass_mla_decode to support num_head < 128 Jun 6, 2025
@Alcanderian Alcanderian changed the title [WIP][perf][sgl-kernel] extend cutlass_mla_decode to support num_head < 128 [perf][sgl-kernel] extend cutlass_mla_decode to support num_head < 128 Jun 7, 2025
@Fridge003 (Collaborator) commented Jun 7, 2025

Does the blocksize in the benchmark results have the same meaning as pagesize in sglang?

@Alcanderian (Collaborator, Author)

Does the blocksize in the benchmark results have the same meaning as pagesize in sglang?

yes

@Fridge003 (Collaborator)

Does the blocksize in the benchmark results have the same meaning as pagesize in sglang?

yes

Nice! I feel we can remove the limitation of page_size=128 for cutlass mla backend

@zhyncs zhyncs merged commit 18efb5e into sgl-project:main Jun 9, 2025
25 of 29 checks passed
walker-ai pushed a commit to walker-ai/sglang that referenced this pull request Jul 8, 2025
Merge branch 'sgl_20250610_sync_tag047 of git@code.alipay.com:Theta/SGLang.git into main

https://code.alipay.com/Theta/SGLang/pull_requests/52


Reviewed-by: 剑川 <jianchuan.gys@antgroup.com>


* [Bugfix] Fix slice operation when chunk size mismatch (sgl-project#6697)
* [Bugfix] Fix ChatCompletion endpoint of mini_lb when stream is set (sgl-project#6703)
* [CI] Fix setup of disaggregation with different tp (sgl-project#6706)
* [PD] Remove Unnecessary Exception Handling for FastQueue.get() (sgl-project#6712)
* Fuse routed_scaling_factor in DeepSeek (sgl-project#6710)
* Overlap two kernels in DeepSeek with communication (sgl-project#6711)
* Minor refactor two-batch overlap (sgl-project#6682)
* Speed up when having padding tokens two-batch overlap (sgl-project#6668)
* [Feature] Support Flashinfer fp8 blockwise GEMM kernel on Blackwell (sgl-project#6479)
* Fix LoRA bench (sgl-project#6719)
* temp
* Fix PP for Qwen3 MoE (sgl-project#6709)
* [feat] triton kernel for get_last_loc (sgl-project#6676)
* [fix] more mem for draft_extend cuda_graph (sgl-project#6726)
* [PD] bug fix:  Update status if nixl receiver send a a dummy req. (sgl-project#6720)
* Tune memory arguments on B200 (sgl-project#6718)
* Add DeepSeek-R1-0528 function call chat template (sgl-project#6725)
* refactor(tool call): Fix BaseFormatDetector tool_index issue and refactor `parse_streaming_increment` (sgl-project#6715)
* Add draft extend CUDA graph for Triton backend (sgl-project#6705)
* refactor apply_w8a8_block_fp8_linear in fp (sgl-project#6545)
* [PD] Support completion endpoint (sgl-project#6729)
* PD Rust LB (PO2) (sgl-project#6437)
* Super tiny enable sole usage of expert distribution metrics and update doc (sgl-project#6680)
* Support picking variants of EPLB algorithms (sgl-project#6728)
* Support tuning DeepEP configs (sgl-project#6742)
* [test] add ut and bm for get_last_loc (sgl-project#6746)
* Fix mem_fraction_static for AMD CI (sgl-project#6748)
* [fix][RL] Fix DeepSeekV3ForCausalLM.post_load_weights for multiple update weight (sgl-project#6265)
* Improve EPLB logical to physical dispatch map (sgl-project#6727)
* Update DeepSeek-R1-0528 function call chat template (sgl-project#6765)
* [PD] Optimize time out logic and add env var doc for mooncake (sgl-project#6761)
* Fix aiohttp 'Chunk too big' in bench_serving (sgl-project#6737)
* Support sliding window in triton backend (sgl-project#6509)
* Fix shared experts fusion error (sgl-project#6289)
* Fix one bug in the grouped-gemm triton kernel (sgl-project#6772)
* update llama4 chat template and pythonic parser (sgl-project#6679)
* feat(tool call): Enhance Llama32Detector for improved JSON parsing in non-stream (sgl-project#6784)
* Support token-level quantization for EP MoE (sgl-project#6782)
* Temporarily lower mmlu threshold for triton sliding window backend (sgl-project#6785)
* ci: relax test_function_call_required (sgl-project#6786)
* Add intel_amx backend for Radix Attention for CPU (sgl-project#6408)
* Fix incorrect LoRA weight loading for fused gate_up_proj (sgl-project#6734)
* fix(PD-disaggregation): Can not get local ip (sgl-project#6792)
* [FIX] mmmu bench serving result display error (sgl-project#6525) (sgl-project#6791)
* Bump torch to 2.7.0 (sgl-project#6788)
* chore: bump sgl-kernel v0.1.5 (sgl-project#6794)
* Improve profiler and integrate profiler in bench_one_batch_server (sgl-project#6787)
* chore: upgrade sgl-kernel v0.1.5 (sgl-project#6795)
* [Minor] Always append newline after image token when parsing chat message (sgl-project#6797)
* Update CI tests for Llama4 models (sgl-project#6421)
* [Feat] Enable PDL automatically on Hopper architecture (sgl-project#5981)
* chore: update blackwell docker (sgl-project#6800)
* misc: cache is_hopper_arch (sgl-project#6799)
* Remove contiguous before Flashinfer groupwise fp8 gemm (sgl-project#6804)
* Correctly abort the failed grammar requests & Improve the handling of abort (sgl-project#6803)
* [EP] Add cuda kernel for moe_ep_pre_reorder (sgl-project#6699)
* Add draft extend CUDA graph for flashinfer backend  (sgl-project#6805)
* Refactor CustomOp to avoid confusing bugs (sgl-project#5382)
* Tiny log prefill time (sgl-project#6780)
* Tiny fix EPLB assertion about rebalancing period and recorder window size (sgl-project#6813)
* Add simple utility to dump tensors for debugging (sgl-project#6815)
* Fix profiles do not have consistent names (sgl-project#6811)
* Speed up rebalancing when using non-static dispatch algorithms (sgl-project#6812)
* [1/2] Add Kernel support for Cutlass based Fused FP4 MoE (sgl-project#6093)
* [Router] Fix k8s Service Discovery (sgl-project#6766)
* Add CPU optimized kernels for topk and rope fusions  (sgl-project#6456)
* fix new_page_count_next_decode (sgl-project#6671)
* Fix wrong weight reference in dynamic EPLB (sgl-project#6818)
* Minor add metrics to expert location updater (sgl-project#6816)
* [Refactor] Rename `n_share_experts_fusion` as `num_fused_shared_experts` (sgl-project#6735)
* [FEAT] Add transformers backend support  (sgl-project#5929)
* [fix] recover auto-dispatch for rmsnorm and rope (sgl-project#6745)
* fix ep_moe_reorder kernel bugs (sgl-project#6858)
* [Refactor] Multimodal data processing for VLM (sgl-project#6659)
* Decoder-only Scoring API (sgl-project#6460)
* feat: add dp-rank to KV events (sgl-project#6852)
* Set `num_fused_shared_experts` as `num_shared_experts` when shared_experts fusion is not disabled (sgl-project#6736)
* Fix one missing arg in DeepEP (sgl-project#6878)
* Support LoRA in TestOpenAIVisionServer and fix fused kv_proj loading bug. (sgl-project#6861)
* support 1 shot allreduce  in 1-node and 2-node using mscclpp (sgl-project#6277)
* Fix Qwen3MoE missing token padding optimization (sgl-project#6820)
* Tiny update error hints (sgl-project#6846)
* Support layerwise rebalancing experts (sgl-project#6851)
* Tiny allow profiler API to auto create directory (sgl-project#6865)
* Support Blackwell DeepEP docker images (sgl-project#6868)
* [EP] Add cuda kernel for moe_ep_post_reorder (sgl-project#6837)
* [theta]merge 0605
* oai: fix openAI client error with single request via batch api (sgl-project#6170)
* [PD] Fix potential perf spike caused by tracker gc and optimize doc (sgl-project#6764)
* Use deepgemm instead of triton for fused_qkv_a_proj_with_mqa (sgl-project#6890)
* [CUTLASS-FP4-MOE]  Introduce CutlassMoEParams class for easy initialization of Cutlass Grouped Gems Metadata (sgl-project#6887)
* bugfix(OAI): Fix image_data processing for jinja chat templates (sgl-project#6877)
* [CPU] enable CI for PRs, add Dockerfile and auto build task (sgl-project#6458)
* AITER backend extension and workload optimizations (sgl-project#6838)
* [theta]merge
* [theta]merge
* [Feature] Support Flashinfer fmha on Blackwell (sgl-project#6930)
* Fix a bug in abort & Improve docstrings for abort (sgl-project#6931)
* Tiny support customize DeepEP max dispatch tokens per rank (sgl-project#6934)
* Sync the changes on cuda graph runners (sgl-project#6932)
* [PD] Optimize transfer queue forward logic for dummy rank (sgl-project#6922)
* [Refactor] image data process in bench_serving (sgl-project#6879)
* [fix] logical_to_all_physical_map index 256 is out of bounds in EP parallel. (sgl-project#6767)
* Add triton fused moe kernel config for E=257 on B200 (sgl-project#6939)
* [sgl-kernel] update deepgemm (sgl-project#6942)
* chore: bump sgl-kernel v0.1.6 (sgl-project#6943)
* Minor compile fused topk (sgl-project#6944)
* [Bugfix] pipeline parallelism and Eagle Qwen2 (sgl-project#6910)
* Tiny re-introduce profile id logging (sgl-project#6912)
* Add triton version as a fused_moe_triton config search key to avoid performace decrease in different Triton version (sgl-project#5955)
* reduce torch.zeros overhead in moe align block size kernel (sgl-project#6369)
* chore: upgrade sgl-kernel v0.1.6 (sgl-project#6945)
* add fbgemm moe grouped gemm kernel benchmark (sgl-project#6924)
* [Docker] Add docker file for SGL Router (sgl-project#6915)
* Disabling mixed chunked prefill when eagle is enabled (sgl-project#6874)
* Add canary for EPLB rebalancing (sgl-project#6895)
* Refactor global_server_args_dict (sgl-project#6866)
* Fuse routed scaling factor in topk_reduce kernel (sgl-project#6220)
* Update server timeout time in AMD CI. (sgl-project#6953)
* [misc] add is_cpu() (sgl-project#6950)
* Add H20 fused MoE kernel tuning configs for DeepSeek-R1/V3 (sgl-project#6885)
* Add a CUDA kernel for fusing mapping and weighted sum for MoE. (sgl-project#6916)
* chore: bump sgl-kernel v0.1.6.post1 (sgl-project#6955)
* chore: upgrade sgl-kernel v0.1.6.post1 (sgl-project#6957)
* [DeepseekR1-FP4] Add Support for nvidia/DeepSeekR1-FP4 model (sgl-project#6853)
* Revert "Fuse routed scaling factor in topk_reduce kernel (sgl-project#6220)" (sgl-project#6968)
* [AMD] Add more tests to per-commit-amd (sgl-project#6926)
* chore: bump sgl-kernel v0.1.7 (sgl-project#6963)
* Slightly improve the sampler to skip unnecessary steps (sgl-project#6956)
* rebase h20 fused_moe config (sgl-project#6966)
* Fix CI and triton moe Configs (sgl-project#6974)
* Remove unnecessary kernels of num_token_non_padded (sgl-project#6965)
* Extend cuda graph capture bs for B200 (sgl-project#6937)
* Fuse routed scaling factor in deepseek (sgl-project#6970)
* Sync cuda graph runners (sgl-project#6976)
* Fix draft extend ut stability with flush cache (sgl-project#6979)
* Fix triton sliding window test case (sgl-project#6981)
* Fix expert distribution dumping causes OOM (sgl-project#6967)
* Minor remove one kernel for DeepSeek (sgl-project#6977)
* [perf][sgl-kernel] extend cutlass_mla_decode to support num_head < 128 (sgl-project#6929)
* Enable more unit tests for AMD CI. (sgl-project#6983)
* Use torch.compile to fuse flash attention decode metadata preparation (sgl-project#6973)
* Eliminate stream sync to speed up LoRA batch init  (sgl-project#6960)
* support qwen3 emebedding (sgl-project#6990)
* Fix torch profiler bugs for bench_offline_throughput.py (sgl-project#6557)
* chore: upgrade flashinfer v0.2.6.post1 jit (sgl-project#6958)
* cleanup tmp dir (sgl-project#7007)
* chore: update pr test xeon (sgl-project#7008)
* Fix cutlass MLA gets almost zero accuracy (sgl-project#6998)
* Update amd nightly models CI. (sgl-project#6992)
* feat: add direct routing strategy to DP worker (sgl-project#6884)
* Fallback to lower triton version for unfound fused moe configs (sgl-project#7013)
* Fix torchvision version for Blackwell (sgl-project#7015)
* Simplify prepare_extend_after_decode (sgl-project#6987)
* Migrate to assertEqual (sgl-project#6741)
* Fix torch version in blackwell dockerfile (sgl-project#7017)
* chore: update pr test xeon (sgl-project#7018)
* Update default settings for blackwell (sgl-project#7023)
* Support both approximate and exact expert distribution collection (sgl-project#6964)
* Add decode req pool (sgl-project#6980)
* [theta]merge 0610
* [theta]merge 0610
* [CI] Add CI workflow for sgl-router docker build (sgl-project#7027)
* Fix fused_moe triton configs (sgl-project#7029)
* CPU: map changes from developing branch in sgl-kernel (sgl-project#6833)
* chore: bump v0.4.7 (sgl-project#7038)
* Update README.md (sgl-project#7040)