Export / import functions to / from a file #1642
This is massively cool. I'll get to reviewing asap!
```cpp
// - constants, which can be used directly
// - a load primitive which has no inputs and will become a constant
//   after the first eval
if (!a.has_primitive() || is_load(a.primitive())) {
```
This change is worth commenting on:
- Previously, if you loaded arrays from a file inside a compiled function, then every call to the function would reload from the file.
- Now only the first call to the function loads, and after that the loaded arrays become constants in the tape.

This seems better to me.. though perhaps that is debatable. It is also used by `import_function`, which makes `Load` primitives for constants, so we get lazy loading even with `import_function`, which is pretty nice.

Lmk thoughts.. I can switch it so compile doesn't force load (more flexible but more dangerous).
I was a bit torn but the more I think about it the more I like it. If someone wants to control the load, they can always pass the state as inputs.
Managing memory is also as easy as before. All you have to do is delete the function after it is called but before you eval (there is the small overhead of deserializing the function when loading again).
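The load-once-then-constant behavior being discussed can be sketched in plain Python (this is just an illustration of the idea, not MLX internals; `load_fn` is a hypothetical stand-in for reading a file):

```python
class LazyLoad:
    """Toy stand-in for a Load primitive: the first eval reads the
    "file" and caches the result; later evals reuse the cached value
    as if it were a constant baked into the tape."""

    def __init__(self, load_fn):
        self._load_fn = load_fn  # hypothetical loader, e.g. a file read
        self._value = None
        self.load_count = 0      # tracks how often the "file" is touched

    def eval(self):
        if self._value is None:
            self._value = self._load_fn()
            self.load_count += 1
        return self._value

cell = LazyLoad(lambda: [1.0, 2.0])
for _ in range(3):
    cell.eval()
print(cell.load_count)  # 1: only the first eval loads
```

Deleting the object (analogous to deleting the imported function) frees the cached value; recreating it just pays the load cost once more.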
It got me thinking about an optimization we can do in compile in general, which is to prune/eval branches of the graph that have no inputs on the leaves. It would be a trade-off between memory and compute, so one would need to be careful there.. but there are some legitimate cases where it's come up for me.

For example, in some of our RoPE implementations we precompute `self._freqs`, but they aren't part of `model.parameters()`. Naively compiling would recompute those at each call of the function, so you have to make sure you eval them before running the compiled function.
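The pruning idea can be sketched on a toy graph (plain Python, not the MLX tape; `Node`, `evaluate`, and `fold` are made up for illustration): any subtree whose leaves contain no function inputs is evaluated once at compile time and replaced by a constant.

```python
class Node:
    def __init__(self, op, children=(), value=None):
        self.op, self.children, self.value = op, list(children), value

def depends_on_input(node, inputs):
    # Identity comparison: inputs are specific placeholder nodes.
    return any(node is i for i in inputs) or any(
        depends_on_input(c, inputs) for c in node.children)

def evaluate(node):
    if node.op == "const":
        return node.value
    if node.op == "add":
        return sum(evaluate(c) for c in node.children)
    raise ValueError(f"cannot evaluate {node.op}")

def fold(node, inputs):
    """Replace input-free subtrees with precomputed constants."""
    if not depends_on_input(node, inputs):
        return Node("const", value=evaluate(node))
    node.children = [fold(c, inputs) for c in node.children]
    return node

x = Node("input")
freqs = Node("add", [Node("const", value=3), Node("const", value=4)])  # like precomputed freqs
graph = fold(Node("add", [x, freqs]), inputs=[x])
print(graph.children[1].op, graph.children[1].value)  # const 7
```

The memory/compute trade-off shows up here too: the folded constant stays resident even if the original leaves could have been cheap to recompute.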
Hmm interesting. It would be fairly easy to write as part of `simplify` or a similar operation, but how would we provide that functionality? Always doing it doesn't seem like the best option, as it might require a lot of memory which the user explicitly doesn't want to keep around. Same goes for the `Load` indeed.
One way of deciding is to base it on user expectations. I'm not sure it's technically feasible, but there are two cases that we currently treat the same that I think people have quite different expectations about. Using load as an example:
```python
def fun1(x):
    return x + mx.load("y.safetensors")["y"]

y = mx.load("y.safetensors")["y"]
def fun2(x):
    return x + y
```

In `fun1` I expect the load to happen every time I call the function (as it does in eager mode). In `fun2` I expect the load to happen only once.

For compile, the behavior used to be like `fun1` and I switched it to `fun2`, just because it's really unusual to write something like `fun1`.
But in theory we could try and distinguish the two cases. I think the expectations are the same for any computation:

```python
def fun1(x):
    return x + complex_fun()

y = complex_fun()
def fun2(x):
    return x + y
```
Getting that behavior in Python is probably doable because we can figure out which inputs are enclosed. Getting it in C++ is probably a lot harder / maybe not possible.
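On the Python side, `inspect.getclosurevars` can detect which values a function encloses, which is one way to tell the two cases apart (a sketch of the detection idea only, not what compile actually does; `load` here is a hypothetical stand-in for `mx.load`):

```python
import inspect

def load(path):
    return [3.0, 4.0]   # hypothetical stand-in for mx.load(...)["y"]

y = load("y.safetensors")  # loaded once, at module scope

def fun1(x):
    return x + sum(load("y.safetensors"))  # load happens on every call

def fun2(x):
    return x + sum(y)                      # uses the captured, pre-loaded y

def captured_arrays(fun, is_array=lambda v: isinstance(v, list)):
    """Names enclosed by `fun` whose values look like arrays (lists here)."""
    cv = inspect.getclosurevars(fun)
    scope = {**cv.globals, **cv.nonlocals}
    return {name for name, val in scope.items() if is_array(val)}

print(captured_arrays(fun1))  # set()  -- only the load function is enclosed
print(captured_arrays(fun2))  # {'y'}  -- an array value is enclosed
```

`fun2` encloses an array-like value, so it is safe to bake in as a constant; `fun1` only encloses the loader itself, so the load is part of the computation.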
mlx/export.h (Outdated)
```cpp
std::function<std::vector<array>(const std::vector<array>&)> import_function(
    std::string path);
```
Another question: should `import_function` return metadata? I can see how it would be useful to get, say, the shapes and/or dtypes of the inputs, maybe the MLX version, etc. in a dict of metadata. Can also wait and see and provide an overload / a `return_metadata` flag in the future.
I think it's fine. I am not too sure why we want to put everything in a single file tbh (weights, metadata and computation graph), but either way, as you say, we can always add a `return_metadata` optional argument.
#2410 requests adding metadata. Is there a plan to support it?
```cpp
@@ -2098,21 +2147,6 @@ class Tanh : public UnaryPrimitive {
  void eval(const std::vector<array>& inputs, array& out);
};

class Uniform : public UnaryPrimitive {
```
Unused 🤷♂️ ..
I updated the exporting API to allow functions in C++ which take a vector and/or a map of keyword arguments. So you can do things like:

```cpp
export_function(file_path, fun, args);
export_function(file_path, fun, kwargs);
export_function(file_path, fun, args, kwargs);
```

And similarly:

```cpp
auto fun = import_function(file_path);
fun(args);
fun(args, kwargs);
fun(kwargs);
```

This makes it a lot safer to export functions that take keyword arguments from Python.
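To see why keyword support makes this safer, here is a toy sketch (hypothetical, not the MLX implementation) of binding a call's args/kwargs to the parameter names recorded at export time; a positional-only scheme would instead silently depend on argument order:

```python
def bind(recorded_names, args, kwargs):
    """Match a call's args/kwargs to the input names recorded at export."""
    if len(args) > len(recorded_names):
        raise ValueError("too many positional arguments")
    bound = dict(zip(recorded_names, args))
    for name, value in kwargs.items():
        if name not in recorded_names:
            raise ValueError(f"unknown keyword: {name}")
        if name in bound:
            raise ValueError(f"duplicate argument: {name}")
        bound[name] = value
    missing = [n for n in recorded_names if n not in bound]
    if missing:
        raise ValueError(f"missing arguments: {missing}")
    # Hand the trace its inputs in the order it was recorded with.
    return [bound[n] for n in recorded_names]

print(bind(["x", "y"], [1], {"y": 2}))  # [1, 2]
```

Mistyped or misordered keyword arguments fail loudly at call time instead of silently feeding the wrong array into the trace.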
I also updated the API to allow exporting multiple traces of the same function with varying inputs. This is really nice e.g. for LLM inference, where prompt prefill takes a mask and sometimes a cache but generation does not take a mask. It is also pretty nice for exporting varying shapes (when you aren't doing shapeless exports). Common constants are serialized only once. For example:

```python
constant = mx.zeros((16, 2048))
mx.eval(constant)

def fun(*args):
    return constant + sum(args)

with mx.exporter(path, fun) as exporter:
    for i in range(5):
        exporter(*[mx.array(1)] * i)
```

The above exports 6 different graphs but only a single copy of `constant`.
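The serialize-common-constants-once behavior can be sketched by keying constants on a content hash (a simplification for illustration; this is not how MLX actually stores them):

```python
import hashlib

class TraceStore:
    """Toy exporter store: traces reference constants by content hash,
    so identical constants are serialized only once."""

    def __init__(self):
        self.constants = {}  # hash -> bytes, each stored exactly once
        self.traces = []     # each trace keeps only hash references

    def add_trace(self, graph, constants):
        refs = []
        for c in constants:
            key = hashlib.sha256(c).hexdigest()
            self.constants.setdefault(key, c)  # no-op if already stored
            refs.append(key)
        self.traces.append({"graph": graph, "constants": refs})

store = TraceStore()
big = bytes(16 * 2048 * 4)   # stands in for the (16, 2048) float32 zeros
for i in range(6):
    store.add_trace(f"graph_{i}", [big])
print(len(store.traces), len(store.constants))  # 6 1
```

Six traces reference the big constant, but the file would only carry one copy of its bytes.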
I think it's safe to review this.. and I probably should not keep growing this diff because it's getting a bit large. I still want to clean up some of the implementation, but overall I think it gets the job done reasonably well.
Here is a little gist that exports Llama 3.1 generation from Python and imports and runs it in C++: https://gist.github.com/awni/ebd1c9faa0e33c5d924561695c15ac7e
It looks really, really good. There is nothing I could find to comment on, really 🤷♂️. There is some amount of domain-specific logic in `FunctionExporter::export_function`, but I think this aligns perfectly with the rest of the code. It is more efficient and relatively self-contained, similar to `eval`, `vjp` etc.

Passing `nullptr` for the fallback of fast primitives is a bit of a hairy situation. One option would be to export the fallback tape on the provided inputs, but a) it is a bit complicated and b) unlikely to be useful.
I agree that is probably the "correct" option, and I think it probably should be done, but like you say it's a bit involved and I didn't want to keep growing this diff. The nice thing about doing it that way is we would be able to transform fast primitives even after export -> import, which I think is pretty neat (though maybe not so useful in practice 😅).
Thanks for reading the diff @angeloskath! I know it's a big one!! I'm going to mark the new APIs as experimental in the doc strings (just to give fair warning that they are still quite new and subject to change), and then I think we should merge and refine, with some license to change what we need based on usage. I'll also follow up this PR with a usage guide for the docs, which we can put either in the developer docs or in our usage section.
Sounds perfect to me!
Adds `export_function` and `import_function` so that we can save and load functions from a file. Makes it possible to use functions written in one language from another language (e.g. Python -> C++). Basically works like so:
In Python:
Then in C++, for example:
Some notes on the implementation:
- Most of it is in `export.cpp`
- Primitives have a `state` which returns the data to save, used by `export.cpp`. But I didn't want to obfuscate / over engineer it much yet until getting some input.