Custom logsumexp #2028
Conversation
Looks great! The memory savings look even greater :-)
Left one nitpick...
#include "mlx/backend/metal/kernels/utils.h" | ||
#include "mlx/backend/metal/kernels/logsumexp.h" | ||
|
||
#define instantiate_logsumexp(name, itype) \ |
Sorry, I think I forgot to comment on it before (it's a nitpick anyway)... Maybe use `instantiate_kernel` here?
```python
# Large
x = mx.random.uniform(shape=(1025,))
x = mx.broadcast_to(mx.random.uniform(shape=(2, 1, 8)), (2, 2, 8))
self.assertTrue(mx.allclose(mx.logsumexp(x), logsumexp(x)))
```
This assertion tests the same thing as the previous one.
Yeah, looks like a mistake.
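For reference, a guess at what the test was presumably meant to check (not the author's actual fix): assert the large 1-D case and the broadcast case separately, with `logsumexp` being the test file's reference implementation as in the snippet above.

```python
# Large
x = mx.random.uniform(shape=(1025,))
self.assertTrue(mx.allclose(mx.logsumexp(x), logsumexp(x)))

# Broadcast input
x = mx.broadcast_to(mx.random.uniform(shape=(2, 1, 8)), (2, 2, 8))
self.assertTrue(mx.allclose(mx.logsumexp(x), logsumexp(x)))
```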
Add this to speed up and reduce memory use when computing `logsumexp`, e.g. during LoRA fine-tuning. The large vocabulary (200k+ in some cases) makes this optimization important. Also, since the `logsumexp` is computed in high precision, this lets us avoid upcasting the logits and instead only upcast the per-token loss prior to the reduction.

QLoRA fine-tuning Gemma 3 1B:
Pre:
Iter 10: Train loss 3.595, Learning Rate 1.000e-05, It/sec 0.227, Tokens/sec 812.842, Trained Tokens 35845, Peak mem 66.335 GB
Post:
Iter 10: Train loss 3.594, Learning Rate 1.000e-05, It/sec 0.345, Tokens/sec 1235.568, Trained Tokens 35845, Peak mem 49.164 GB
And a microbenchmark:
Pre: 1588.345
Post: 257.241
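To make the precision point concrete, here is a minimal sketch (mine, not the PR's code) of a cross-entropy loss that keeps the logits in low precision, relies on the high-precision `logsumexp` described above, and upcasts only the per-token loss before the reduction. The function name, shapes, and dtypes are illustrative.

```python
import mlx.core as mx

def per_token_ce(logits, targets):
    # Logits stay in low precision (e.g. bfloat16); per the PR description,
    # the logsumexp itself accumulates in high precision, so there is no
    # need to upcast the full [tokens, vocab] logits array.
    lse = mx.logsumexp(logits, axis=-1)
    # Pick out the logit of the target token for each position.
    target_logits = mx.take_along_axis(
        logits, mx.expand_dims(targets, -1), axis=-1
    ).squeeze(-1)
    # Upcast only the per-token loss, just before the reduction.
    loss = (lse - target_logits).astype(mx.float32)
    return mx.mean(loss)

# Hypothetical usage with a large vocabulary.
logits = mx.random.normal(shape=(8, 200_000)).astype(mx.bfloat16)
targets = mx.random.randint(0, 200_000, shape=(8,))
print(per_token_ce(logits, targets))
```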