Added JaxPrinter #23627
✅ Hi, I am the SymPy bot (v167). I'm here to help you write a release notes entry. Please read the guide on how to write release notes. Your release notes are in good order. Here is what the release notes will look like:
This will be added to https://github.com/sympy/sympy/wiki/Release-Notes-for-1.11. Click here to see the pull request description that was parsed.
Update: The release notes on the wiki have been updated.
🟠Hi, I am the SymPy bot (v167). I've noticed that some of your commits add or delete files. Since this is sometimes done unintentionally, I wanted to alert you about it. This is an experimental feature of SymPy Bot. If you have any feedback on it, please comment at sympy/sympy-bot#75. The following commits add new files:
If these files were added/deleted on purpose, you can ignore this message.
Thank you for this contribution. I had a quick look and it looks really nice.
The Jax printer looks good; the parametric part I'm less sure about: isn't this approach fragile to e.g. reordering of the arguments? What is the benefit over simply preprocessing the expression and introducing parameter symbols manually?
Would it perhaps be possible to split the PR into two parts, first the Jax printer, and then a second PR for the parametric part?
Let's see what the others say.
I also don't understand the parametric part. If you internally replace all numeric nodes with symbols, how would the user know how to provide values to the resulting numeric function in the right order and consistently? For example, what if there are two 1's in the expression? Are you replacing those with the same symbol or with two different symbols and, if two, what if the user then provides different values for them? Maybe a description of the use case for parametric would help clear it up.
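The manual preprocessing suggested above can be sketched as follows. This is only an illustration, not the PR's parametric implementation: the expression, the symbol names a and b, and the choice of the "math" backend are assumptions made for the example. Because the user writes the mapping and the argument tuple explicitly, the order of the parameters is known, and every occurrence of the same constant maps to the same symbol.

```python
import sympy as sp

x, a, b = sp.symbols("x a b")

# Expression with hard-coded numeric constants (illustrative only).
expr = 2.0 * sp.sin(x) + 3.0 * x

# Manual preprocessing: the user chooses which constants become parameters.
# xreplace swaps exact sub-nodes, so every Float(2.0) becomes `a`.
param_expr = expr.xreplace({sp.Float(2.0): a, sp.Float(3.0): b})

# The argument order is fixed explicitly here, not inferred from the tree.
# modules="math" keeps the sketch free of third-party numeric backends.
f = sp.lambdify((x, a, b), param_expr, modules="math")

value = f(0.5, 2.0, 3.0)
reference = float(expr.subs(x, 0.5))
```

Passing the original constants back in reproduces the original expression's value, while different values for a and b can be supplied without rebuilding the expression.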
This sounds like a good idea.
Benchmark results from GitHub Actions
Lower numbers are good, higher numbers are bad. A ratio less than 1
Significantly changed benchmark results (PR vs master)
Significantly changed benchmark results (master vs previous release)
       before           after         ratio
     [77f1d79c]       [a28a214e]
     <sympy-1.10.1^0>
+     100±0.8ms        188±3ms         1.87  sum.TimeSum.time_doit
Full benchmark results can be found as artifacts in GitHub Actions
Is the parametric part in a separate PR now? Now that the parametric part has been removed from this PR, can someone review this? It looks like a good addition for 1.11.
Also there's a simple merge conflict.
The parametric pull request has been moved to #23661.
I made a second pass of reviewing. Some minor comments/questions.
sympy/printing/tests/test_jax.py
from sympy.core.symbol import symbols
n = symbols('n', integer=True)
N = MatrixSymbol("M", n, n)
raises(NotImplementedError, lambda: lambdify(N, N + Identity(n)))
is JAX involved here?
This test has been copied over from the NumpyPrinter tests. I must have forgotten to change the lambdify expression to call the JaxPrinter instead. I will update the tests to rectify this.
Updated in commit 4b3869b
I'm removing the 1.11 milestone. It would be good to get this into the 1.11 release, and if it is merged soon then it will be. Otherwise there's no particular reason for this to be a release blocker.
This looks ready from my point of view.
Thank you @tttc3 for this contribution.
Folks over here may be interested in sympy2jax, which also aims to do the sympy->JAX conversion (and back again). I think the main differences from the approach here are (a) parameter handling and (b) there's no printing; it's just a normal JAX construct.
Thanks @patrick-kidger. Would it make sense to include sympy2jax in SymPy itself? Would you say it's strictly better than what's included here? I can imagine benefits to the printing approach (like being easier to debug), but I suspect the sympy2jax approach is better in most cases.
IMO it is strictly better (I made it for a reason after all) but obviously I'm not an unbiased observer here. I don't think it makes sense to bake it into SymPy itself. It makes an opinionated choice of library to build on: Equinox as opposed to Flax or Haiku. And longer-term I don't think SymPy should need to implement a SymPy-to-Foo for those Foo that have yet to be released. Delegating to some external library has worked so far for PyTorch.
References to other Issues or PRs
Adds JAX printing support as discussed in #20516. Parametric support has been moved to #23661.
Brief description of what is fixed or changed
Allows lambdify to output JAX-compatible functions. The PR originally also provided parametric support for all children of AbstractPythonCodePrinter (since moved to #23661). The parametric option attempts to provide native support for the problems that @patrick-kidger's sympytorch and @MilesCranmer's sympy2jax packages solve. I hope this implementation achieves this, or will achieve it if #20516 is merged.
Release Notes
Added JaxPrinter, allowing sympy.lambdify to operate on and produce JAX functions. Example: f = lambdify(variables, expr, 'jax').