

@awni (Member) commented Mar 4, 2025

Benchmarks on an M4 Max with the following config:

L = 16384
Lq = 4
H = 32
H_k = H // 4
D = 128
V = 128

Timing sdpa ... 4.70581 msec
Timing attention ... 16.99894 msec
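The shapes above describe grouped-query attention: a few query rows (Lq) attend over a long key/value sequence (L), and every H // H_k query heads share one key/value head. The pure-Python sketch below makes that computation concrete at tiny sizes. It assumes that `V` is the value head dimension and that query head `h` reads key/value head `h // (H // H_k)`, which is the usual grouped-query mapping. The helper names are hypothetical, and this is an unfused reference for the math, not the Metal kernel or the exact "attention" baseline being timed.

```python
import math
import random


def gqa_attention(q, k, v):
    """Unfused grouped-query attention over nested lists (hypothetical helper).

    q: [H][Lq][D], k: [H_k][L][D], v: [H_k][L][V]; returns [H][Lq][V].
    Query head h reads key/value head h // (H // H_k).
    """
    H, H_k = len(q), len(k)
    group = H // H_k                     # query heads sharing one KV head
    scale = 1.0 / math.sqrt(len(q[0][0]))
    out = []
    for h in range(H):
        kh, vh = k[h // group], v[h // group]
        rows = []
        for qi in q[h]:
            # Scaled dot-product scores against every key position.
            scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in kh]
            m = max(scores)              # subtract the max for numerical stability
            w = [math.exp(s - m) for s in scores]
            z = sum(w)
            # Softmax-weighted sum of the value vectors.
            rows.append([sum(wj * vj[c] for wj, vj in zip(w, vh)) / z
                         for c in range(len(vh[0]))])
        out.append(rows)
    return out


def randn(rng, *shape):
    """Nested list of standard-normal samples with the given shape."""
    if len(shape) == 1:
        return [rng.gauss(0.0, 1.0) for _ in range(shape[0])]
    return [randn(rng, *shape[1:]) for _ in range(shape[0])]


# Scaled-down version of the benchmark config: H_k = H // 4 and Lq much smaller than L.
L, Lq, H, D, V = 64, 4, 8, 16, 16
H_k = H // 4
rng = random.Random(0)
q = randn(rng, H, Lq, D)
k = randn(rng, H_k, L, D)
v = randn(rng, H_k, L, V)

o = gqa_attention(q, k, v)
print(len(o), len(o[0]), len(o[0][0]))  # 8 4 16
```

Because each key/value head is shared by a whole group of query heads, a fused kernel has the opportunity to load each key/value block once per group rather than once per query head; how the MLX kernel actually schedules this is not shown here.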

Also updated RoPE to route to a faster path for shapes like [1, H, L, D].
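For reference, RoPE rotates each pair of features by an angle proportional to the token position, so for a [1, H, L, D] input the position varies only along the L axis. The sketch below shows that math in pure Python, assuming the split-half pairing (feature i paired with i + D/2) and base 10000. The helper `rope` is hypothetical; it illustrates what is computed, not how the faster path in MLX is implemented.

```python
import math


def rope(x, base=10000.0, offset=0):
    """Split-half rotary embedding for a nested list shaped [B][H][L][D] (hypothetical helper).

    The position of row l is offset + l; feature pair (i, i + D/2) is rotated
    by angle position * base ** (-2i / D). D is assumed to be even.
    """
    out = []
    for batch in x:
        heads = []
        for head in batch:
            rows = []
            for l, vec in enumerate(head):
                half = len(vec) // 2
                pos = offset + l
                row = list(vec)
                for i in range(half):
                    theta = pos * base ** (-2.0 * i / len(vec))
                    c, s = math.cos(theta), math.sin(theta)
                    a, b = vec[i], vec[i + half]
                    row[i] = a * c - b * s           # 2-D rotation of the pair
                    row[i + half] = a * s + b * c
                rows.append(row)
            heads.append(rows)
        out.append(heads)
    return out


# A [1, H, L, D] input: batch of one, positions vary only along the L axis.
H, L, D = 2, 5, 8
x = [[[[float(h + l + d) for d in range(D)] for l in range(L)] for h in range(H)]]
y = rope(x)
```

Since every head shares the same position sequence 0..L-1 (shifted by `offset`), the rotation angles depend only on (l, i) and not on h, which is the structure a specialized path for this layout could exploit.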

The main goal is to speed up speculative decoding. I'll share some benchmarks in ml-explore/mlx-examples#1319.

@angeloskath (Member) left a comment

Awesome!

Was the choice of query_transposed over passing strides and so on made for performance reasons?

@awni (Member, Author) commented Mar 4, 2025

> The choice for query_transposed vs passing strides and so on was for performance reasons?

Good question. I didn't try passing the query strides in, so I can't say whether there is much of a performance difference; it's probably minor. The diff seemed simpler with the function constant. On the other hand, allowing arbitrary strides in the sequence and head dimensions would be more general.
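To illustrate the general-stride alternative discussed here, the pure-Python sketch below assumes a query buffer stored contiguously as [B, L, H, D] that a kernel wants to read as [B, H, L, D]. Swapping the sequence and head strides yields a zero-copy view that addresses exactly the same elements as an explicit transpose. The helper names are hypothetical and unrelated to the MLX source, and the exact layout that query_transposed covers is an assumption.

```python
def row_major_strides(shape):
    """Element strides of a contiguous (row-major) array with the given shape."""
    strides, acc = [], 1
    for n in reversed(shape):
        strides.append(acc)
        acc *= n
    return tuple(reversed(strides))


def flat_offset(idx, strides):
    """Flat buffer offset of a multi-index under the given strides."""
    return sum(i * s for i, s in zip(idx, strides))


# Query projected as [B, L, H, D] and stored contiguously.
B, L, H, D = 1, 4, 3, 2
buf = list(range(B * L * H * D))
s_blhd = row_major_strides((B, L, H, D))               # (24, 6, 2, 1)

# Zero-copy [B, H, L, D] view: swap the sequence and head strides.
s_view = (s_blhd[0], s_blhd[2], s_blhd[1], s_blhd[3])  # (24, 2, 6, 1)

# Explicitly transposed contiguous copy, for comparison.
s_bhld = row_major_strides((B, H, L, D))
copy = [0] * len(buf)
for b in range(B):
    for h in range(H):
        for l in range(L):
            for d in range(D):
                copy[flat_offset((b, h, l, d), s_bhld)] = buf[flat_offset((b, l, h, d), s_blhd)]

# Element (b, h, l, d) is the same whether read through the strided view or the copy.
print(buf[flat_offset((0, 1, 2, 0), s_view)], copy[flat_offset((0, 1, 2, 0), s_bhld)])  # 14 14
```

A boolean function constant specializes the kernel for one fixed permutation, while passing strides for the sequence and head dimensions covers any such view at the cost of a little extra index arithmetic per load, which is consistent with the expectation above that the performance difference is probably minor.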

@awni merged commit e613d0e into main on Mar 4, 2025 (5 checks passed).
@awni deleted the batch_query_sdpa branch on March 4, 2025 18:59.