@awni awni commented Mar 30, 2025

Add this to speed up and reduce memory use when computing logsumexp during, e.g., LoRA fine-tuning. The large vocabulary (200k+ in some cases) makes this optimization important. Also, since the logsumexp is computed in high precision, this lets us avoid upcasting the logits and instead only upcast the per-token loss prior to the reduction.
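
For context, the per-token cross-entropy loss can be written as the logsumexp of a row of logits minus the target logit, which is why logsumexp dominates the cost when the vocabulary is large. The following is a minimal plain-Python sketch of that identity, not the kernel added in this PR. The helper names `logsumexp` and `per_token_loss` are made up for illustration, and since Python floats are always double precision the sketch does not model the upcasting behavior described above.

```python
import math

def logsumexp(row):
    """Numerically stable log(sum(exp(v) for v in row))."""
    m = max(row)
    if math.isinf(m):
        return m  # avoids inf - inf = nan when the max is not finite
    # Subtracting the max keeps every exponent <= 0, so exp cannot overflow.
    return m + math.log(sum(math.exp(v - m) for v in row))

def per_token_loss(logits, targets):
    """Cross-entropy per token: logsumexp(logits) - logits[target]."""
    return [logsumexp(row) - row[t] for row, t in zip(logits, targets)]

# Hypothetical toy data: 2 tokens, vocabulary of 3.
logits = [[1000.0, 1001.0, 999.0], [0.0, 0.0, 0.0]]
targets = [1, 2]

losses = per_token_loss(logits, targets)
# The reduction over tokens happens only after the per-token losses exist.
mean_loss = sum(losses) / len(losses)
```

A naive `math.log(sum(math.exp(v) for v in row))` would overflow on the first row, which is the usual reason for the max-subtraction step.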

QLoRA fine tuning Gemma 3 1B:

Pre: Iter 10: Train loss 3.595, Learning Rate 1.000e-05, It/sec 0.227, Tokens/sec 812.842, Trained Tokens 35845, Peak mem 66.335 GB

Post: Iter 10: Train loss 3.594, Learning Rate 1.000e-05, It/sec 0.345, Tokens/sec 1235.568, Trained Tokens 35845, Peak mem 49.164 GB

And a microbenchmark:

import mlx.core as mx

shape = (4096, 120_000)
x = mx.random.uniform(shape=shape)

def fun(x):
    return [mx.logsumexp(x, axis=-1, keepdims=True) for _ in range(100)]

Pre: 1588.345
Post: 257.241
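
The comment does not say how the numbers above were measured or what unit they are in. As a rough sketch of one way to time a function like `fun`, the helper below runs a few warmup calls and reports the best wall-clock time in milliseconds. The name `time_ms` and its parameters are hypothetical and not part of MLX. Because MLX evaluates lazily, the outputs would presumably need to be forced (for example by passing `force=mx.eval`) so that the computation itself is timed rather than just graph construction.

```python
import time

def time_ms(fn, *args, warmup=2, repeats=5, force=lambda out: out):
    """Return the best wall-clock time of fn(*args) in milliseconds.

    `force` is called on the result so lazily evaluated outputs are
    materialized inside the timed region (assumption: the caller passes
    something suitable for their framework; the default does nothing).
    """
    for _ in range(warmup):
        force(fn(*args))  # warm caches / trigger any one-time compilation
    best = float("inf")
    for _ in range(repeats):
        tic = time.perf_counter()
        force(fn(*args))
        best = min(best, (time.perf_counter() - tic) * 1e3)
    return best

# Stand-in workload so the sketch runs without MLX installed.
elapsed = time_ms(sum, list(range(1000)))
```

Taking the minimum over several repeats is a design choice that reduces noise from other processes; a mean or median would also be reasonable.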

@awni awni force-pushed the custom_logsumexp branch from e292e8b to bc91df8 Compare March 30, 2025 23:35
@awni awni force-pushed the custom_logsumexp branch from bc91df8 to aadc758 Compare March 30, 2025 23:59

@angeloskath angeloskath left a comment

Looks great! The memory savings look even greater :-)

Left one nitpick...

#include "mlx/backend/metal/kernels/utils.h"
#include "mlx/backend/metal/kernels/logsumexp.h"

#define instantiate_logsumexp(name, itype) \

Sorry I think I forgot to comment it before (it's a nitpick anyway)...

Maybe use instantiate_kernel here?

@awni awni merged commit de5f38f into main Mar 31, 2025
4 checks passed
@awni awni deleted the custom_logsumexp branch March 31, 2025 14:36
# Large
x = mx.random.uniform(shape=(1025,))
x = mx.broadcast_to(mx.random.uniform(shape=(2, 1, 8)), (2, 2, 8))
self.assertTrue(mx.allclose(mx.logsumexp(x), logsumexp(x)))
Collaborator

This assertion tests the same thing as the previous one.

Member Author

Yea looks like a mistake
