
Conversation


@awni awni commented Dec 3, 2024

Adds export_function and import_function so that we can save and load functions from a file. This makes it possible to use functions written in one language from another language (e.g. Python -> C++).

Basically works like so:

In Python:

# Note: the model parameters are captured and saved as part of the export
# An alternative is to make them inputs to forward
def forward(x):
    return model(x)

example_x = mx.zeros(shape=(batch_size, input_dim))

# Export to file using example input
mx.export_function("model.mlxfn", forward, example_x)

Then in C++, for example:

  auto example_x = random::uniform({batch_size, input_dim});

  // Import the function
  auto forward = import_function("model.mlxfn");

  // Call the imported function
  auto out = forward({example_x})[0];

Some notes on the implementation:

  • Reuses a lot of the compile infrastructure which simplifies things dramatically
  • The serialization of everything is mostly decoupled from the rest of the code and kept in export.cpp
  • Serializing primitives that have member variables requires some way of accessing them. The API is not opinionated about this (the primitive interface didn't change at all), but the convention I'm using is to have a state() method which returns the data to save
  • We can likely use templates / the preprocessor to reduce more boilerplate in some of the serialization code in export.cpp, but I didn't want to obfuscate / over-engineer it until getting some input

@angeloskath
Member

This is massively cool. I'll get to reviewing ASAP!

// - constants, which can be used directly
// - a load primitive which has no inputs and will become a constant
// after the first eval
if (!a.has_primitive() || is_load(a.primitive())) {
Member Author:

This change is worth commenting on:

  • Previously if you loaded arrays from a file inside a compiled function then every call to the function would reload from the file.
  • Now only the first call to the function loads and after that the loaded arrays become constants in the tape

This seems better to me, though perhaps that's debatable. It is also used by import_function, which creates Load primitives for constants, so we get lazy loading with import_function as well, which is pretty nice.

Lmk your thoughts. I can switch it so compile doesn't force the load (more flexible but more dangerous).

Member:

I was a bit torn but the more I think about it the more I like it. If someone wants to control the load, they can always pass the state as inputs.

Managing memory is also as easy as before. All you have to do is delete the function after it is called but before you eval (there is the small overhead of deserializing the function when loading again).

@awni (Member Author), Dec 24, 2024:

It got me thinking about an optimization we could do in compile in general: prune/eval branches of the graph whose leaves have no inputs. It would be a trade-off between memory and compute, so one would need to be careful there, but there are some legitimate cases where it's come up for me.

For example, in some of our RoPE implementations we precompute self._freqs, but those aren't part of model.parameters(). Naively compiling would recompute them at each call of the function, so you have to make sure you eval them before running the compiled function.

Member:

Hmm, interesting. It would be fairly easy to write as part of simplify or a similar operation, but how would we provide that functionality? Always doing it doesn't seem like the best option, as it might require a lot of memory which the user explicitly doesn't want to keep around. The same goes for Load, indeed.

Member Author:

One way of deciding is to do it based on user expectations. I'm not sure it's technically feasible, but there are two cases that we currently treat the same that I think people have quite different expectations about. Using load as an example:

def fun1(x):
  return x + mx.load("y.safetensors")["y"]

y = mx.load("y.safetensors")["y"]

def fun2(x):
  return x + y

In fun1 I expect the load to happen every time I call the function (as it does in eager mode). In fun2 I expect the load to happen only once.

For compile the behavior used to be like fun1 and I switched it to fun2 just because it's really unusual to write something like fun1.

But in theory we could try to distinguish the two cases. I think the expectations are the same for any computation:

def fun1(x):
  return x + complex_fun()

y = complex_fun()

def fun2(x):
  return x + y

Member Author:

Getting that behavior in Python is probably doable because we can figure out which inputs are closed over; getting it in C++ is probably a lot harder / maybe not possible.

mlx/export.h Outdated
Comment on lines 21 to 22
std::function<std::vector<array>(const std::vector<array>&)> import_function(
std::string path);
Member Author:

Another question: should import_function return metadata? I can see how it would be useful to get, say, the shapes and/or dtypes of the inputs, maybe the MLX version, etc., in a dict of metadata. We can also wait and see, and provide an overload / a return_metadata flag in the future.

Member:

I think it's fine. I am not too sure why we want to put everything in a single file tbh (weights, metadata and computation graph) but either way, as you say, we can always add a return_metadata optional argument.

Contributor:

#2410 requests adding metadata.

Is there a plan to support it?

@@ -2098,21 +2147,6 @@ class Tanh : public UnaryPrimitive {
void eval(const std::vector<array>& inputs, array& out);
};

class Uniform : public UnaryPrimitive {
Member Author:

Unused 🤷‍♂️ ..

awni force-pushed the export_import branch 3 times, most recently from fd6520c to 454f44c, on December 12, 2024
awni force-pushed the export_import branch 5 times, most recently from 65b15ad to 28291dc, on December 23, 2024
@awni

awni commented Dec 23, 2024

I updated the exporting API to allow C++ functions that take a vector of positional arguments and/or a map of keyword arguments:

So you can do things like:

export_function(file_path, fun, args);
export_function(file_path, fun, kwargs);
export_function(file_path, fun, args, kwargs);

And similarly:

auto fun = import_function(file_path);
fun(args);
fun(args, kwargs);
fun(kwargs);

args is a vector of arrays and kwargs is a map of string keys with array values.

This makes it a lot safer to export functions that take keyword arguments from Python.

@awni

awni commented Dec 23, 2024

I also updated the API to allow exporting multiple traces of the same function with varying inputs. This is really nice, e.g. for LLM inference, where prompt prefill takes a mask and sometimes a cache but generation does not take a mask. It is also pretty nice for exporting varying shapes (when you aren't doing shapeless exports). Common constants are serialized only once.

For example:

constant = mx.zeros((16, 2048))
mx.eval(constant)

def fun(*args):
    return constant + sum(args)

with mx.exporter(path, fun) as exporter:
    for i in range(5):
        exporter(*[mx.array(1)] * i)

The above exports 5 different graphs (one per call) but only a single copy of constant.

@awni awni marked this pull request as ready for review December 23, 2024 16:14
@awni

awni commented Dec 23, 2024

I think it's safe to review this, and I probably shouldn't keep growing this diff because it's getting a bit large.

I still want to clean-up some of the implementation but overall I think it gets the job done reasonably well.

@awni awni changed the title [WIP] Export / import functions to / from a file Export / import functions to / from a file Dec 23, 2024
@awni

awni commented Dec 23, 2024

Here is a little gist that exports Llama 3.1 generation from Python and imports and runs it in C++: https://gist.github.com/awni/ebd1c9faa0e33c5d924561695c15ac7e

@angeloskath (Member) left a comment:

It looks really, really good. There is nothing I could find to comment on, really 🤷‍♂️. There is some amount of domain-specific logic in FunctionExporter::export_function, but I think this aligns perfectly with the rest of the code. It is more efficient and relatively self-contained, similar to eval, vjp, etc.

Passing nullptr for the fallback of fast primitives is a bit of a hairy situation. One option would be to export the fallback tape on the provided inputs but a) it is a bit complicated b) unlikely to be useful.

@awni

awni commented Dec 24, 2024

Passing nullptr for the fallback of fast primitives is a bit of a hairy situation. One option would be to export the fallback tape on the provided inputs but a) it is a bit complicated b) unlikely to be useful.

I agree that is probably the "correct" option, and I think it probably should be done. But, like you say, it's a bit involved and I didn't want to keep growing this diff. The nice thing about doing it that way is that we would be able to transform fast primitives even after export -> import, which I think is pretty neat (though maybe not so useful in practice 😅).

@awni

awni commented Dec 24, 2024

Thanks for reading the diff @angeloskath! I know it's a big one!! I'm going to mark the new APIs as experimental in the docstrings (just to give fair warning that they are still quite new and subject to change), and then I think we should merge and refine, with some license to change what we need based on usage. I'll also follow up this PR with a usage guide for the docs, which we can put either in the developer docs or in our usage section.

@angeloskath

Sounds perfect to me!

@awni awni merged commit 4ba0c24 into main Dec 24, 2024
5 checks passed
@awni awni deleted the export_import branch December 24, 2024 19:19