awni (Member) commented Jan 8, 2025

Add an optional boolean mask to the vector scaled dot-product attention (SDPA) Metal kernel:

Timing sdpa ... 3.01268 msec
Timing attention ... 7.13981 msec
Timing attention_mask ... 7.50264 msec
Timing sdpa_mask ... 2.24858 msec

Using SDPA with the mask is much faster than the naive attention implementation with a mask: 2.25 msec vs. 7.50 msec, about 3.3x.
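For reference, here is a minimal plain-Python sketch of what a boolean mask means for single-query ("vector") attention. It is not the Metal kernel from this PR. It assumes `mask[j]` is True where key `j` may be attended to (the kernel's actual convention isn't stated here), and it assumes at least one key is unmasked. The function name `vector_sdpa` is hypothetical.

```python
import math
from typing import List, Optional, Sequence


def vector_sdpa(
    q: Sequence[float],
    keys: Sequence[Sequence[float]],
    values: Sequence[Sequence[float]],
    scale: float,
    mask: Optional[Sequence[bool]] = None,
) -> List[float]:
    """Reference attention output for one query vector against N keys/values.

    Assumes mask[j] is True where key j may be attended to, and that at least
    one key is unmasked (otherwise every score is -inf and softmax is undefined).
    """
    scores: List[float] = []
    for j, k in enumerate(keys):
        if mask is not None and not mask[j]:
            # Masked-out key: -inf becomes weight 0 after the softmax below.
            scores.append(-math.inf)
        else:
            scores.append(scale * sum(qi * ki for qi, ki in zip(q, k)))

    # Numerically stable softmax: subtract the largest score before exp().
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)

    # Output is the softmax-weighted sum of the value vectors.
    out = [0.0] * len(values[0])
    for w, v in zip(weights, values):
        for d, vd in enumerate(v):
            out[d] += (w / total) * vd
    return out
```

For example, with `q = [1.0, 0.0]`, `keys = [[1.0, 0.0], [0.0, 1.0]]` and `values = [[10.0], [20.0]]`, passing `mask=[False, True]` returns `[20.0]`, the same result as dropping the first key entirely. Masking with `-inf` before the softmax, rather than zeroing weights afterwards, keeps the remaining weights normalized.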

awni requested review from angeloskath and barronalex on January 8, 2025 01:15
barronalex (Contributor) left a comment


That’s an awesome speed up!!
Super clean with the function constant too 🚀
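The review mentions a function constant. In Metal, a value declared like `constant bool has_mask [[function_constant(N)]];` is fixed when the specialized function is created from the library, so the compiler can drop the mask code from the no-mask variant. The name `has_mask` and index `N` are illustrative; the PR's actual declaration isn't shown here. Python has no direct equivalent, so the sketch below only mirrors the idea: pick a branch-free helper once, up front, instead of re-testing a flag per key. `make_mask_fn` is a hypothetical name.

```python
import math
from typing import Callable, Optional, Sequence

# (score, key index j, optional boolean mask) -> possibly masked score
MaskFn = Callable[[float, int, Optional[Sequence[bool]]], float]


def make_mask_fn(has_mask: bool) -> MaskFn:
    """Choose the masking behavior once, up front.

    Rough analogue of specializing a kernel on a bool function constant: the
    function returned here contains no has_mask check, so a per-key loop that
    calls it never re-tests the flag.
    """

    def masked(score: float, j: int, mask: Optional[Sequence[bool]]) -> float:
        # mask[j] is assumed True where key j may be attended to.
        assert mask is not None
        return score if mask[j] else -math.inf

    def unmasked(score: float, j: int, mask: Optional[Sequence[bool]]) -> float:
        return score

    return masked if has_mask else unmasked
```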

```metal
score = simd_sum(score);
// Compute the i-th score
U score = 0;
for (int j = 0; j < elem_per_thread; j++) {
```
barronalex (Contributor):
+1 for removing the shadowing on i here
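The excerpt above computes each score in two stages: every thread in a SIMD-group accumulates a partial dot product over `elem_per_thread` elements, then `simd_sum` adds the partials across the group. Below is a minimal Python emulation of that pattern. The excerpt doesn't show the loop body, so the per-lane layout (lane `t` owning a contiguous chunk of `elem_per_thread` elements) is an assumption, as is the SIMD width of 32. `simd_dot` and the Python `simd_sum` are hypothetical stand-ins, not Metal or library code. The inner index is `j`, so it doesn't shadow an outer `i`.

```python
from typing import List, Sequence

SIMD_WIDTH = 32  # threads per SIMD-group on Apple GPUs (assumed for this sketch)


def simd_sum(lane_values: Sequence[float]) -> float:
    """Stand-in for Metal's simd_sum: the total over all lanes.

    On the GPU every lane receives this same total; here we just return it once.
    """
    return sum(lane_values)


def simd_dot(q: Sequence[float], k: Sequence[float], simd_width: int = SIMD_WIDTH) -> float:
    """Dot product split across simd_width lanes, then reduced with simd_sum."""
    if len(q) != len(k) or len(q) % simd_width != 0:
        raise ValueError("vector length must be a multiple of simd_width")
    elem_per_thread = len(q) // simd_width
    partials: List[float] = []
    for lane in range(simd_width):
        # Each lane owns a contiguous chunk of elem_per_thread elements (assumed layout).
        score = 0.0
        for j in range(elem_per_thread):
            idx = lane * elem_per_thread + j
            score += q[idx] * k[idx]
        partials.append(score)
    # One cross-lane reduction turns the partial sums into the full q·k score.
    return simd_sum(partials)
```

With a 64-element vector and 32 lanes, each lane handles two elements before the single reduction.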

awni force-pushed the bool_mask_vector_sdpa branch from 970314a to 7ae5624 on January 8, 2025 03:50
awni merged commit d1766f2 into main on January 8, 2025 (5 checks passed)
awni deleted the bool_mask_vector_sdpa branch on January 8, 2025 04:24