@tttc3 (Contributor) commented Jun 14, 2022

References to other Issues or PRs

Adds JAX ~~and parametric printing support~~ as discussed in #20516.
Parametric support has been moved to #23661.

Brief description of what is fixed or changed

Allows lambdify to output JAX-compatible functions. ~~and provides parametric support for all children of AbstractPythonCodePrinter.~~

~~The parametric option attempts to provide native support for the problems that @patrick-kidger's sympytorch and @MilesCranmer's sympy2jax packages solve. I hope this implementation achieves this, or will achieve it if #20516 is merged.~~

Release Notes

  • printing
    • Added JaxPrinter, allowing sympy.lambdify to operate on and produce JAX functions. Example: f = lambdify(variables, expr, 'jax').

@sympy-bot commented Jun 14, 2022

Hi, I am the SymPy bot (v167). I'm here to help you write a release notes entry. Please read the guide on how to write release notes.

Your release notes are in good order.

Here is what the release notes will look like:

  • printing
    • Added JaxPrinter, allowing sympy.lambdify to operate on and produce JAX functions. Example: f = lambdify(variables, expr, 'jax'). (#23627 by @tttc3)

This will be added to https://github.com/sympy/sympy/wiki/Release-Notes-for-1.11.

The pull request description that was parsed:
#### References to other Issues or PRs
Adds JAX ~~and parametric printing support~~ as discussed in #20516 
Parametric support has been moved to #23661.

#### Brief description of what is fixed or changed
Allows `lambdify` to output JAX-compatible functions. ~~and provides parametric support for all children of `AbstractPythonCodePrinter`.~~

~~The parametric option attempts to provide native support for the problems that @patrick-kidger's `sympytorch` and @milescranmer's `sympy2jax` packages solve. I hope this implementation achieves this, or will achieve it if #20516 is merged.~~

#### Release Notes
<!-- BEGIN RELEASE NOTES -->
* printing
  * Added `JaxPrinter`, allowing `sympy.lambdify` to operate on and produce JAX functions. Example: `f = lambdify(variables, expr, 'jax')`.
<!-- END RELEASE NOTES -->

Update

The release notes on the wiki have been updated.

@sympy-bot commented Jun 14, 2022

🟠

Hi, I am the SymPy bot (v167). I've noticed that some of your commits add or delete files. Since this is sometimes done unintentionally, I wanted to alert you about it.

This is an experimental feature of SymPy Bot. If you have any feedback on it, please comment at sympy/sympy-bot#75.

The following commits add new files:

  • 675c8da:
    • sympy/printing/tests/test_jax.py

If these files were added/deleted on purpose, you can ignore this message.

@bjodah (Member) left a comment

Thank you for this contribution, I had a quick look and it looks really nice.

The Jax printer looks good, the parametric part I'm less sure about: isn't this approach fragile to e.g. reordering of the arguments? What is the benefit over simply preprocessing the expression and introducing parameter symbols manually?
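
This is a minimal sketch of the manual alternative bjodah describes, assuming the constants are known up front. Each numeric constant is swapped for a named parameter symbol with `xreplace`. The argument tuple passed to `lambdify` then pins down the calling order explicitly. The symbols `a` and `b` are illustrative names, not part of either PR.

```python
from sympy import Float, lambdify, sin, symbols

x, a, b = symbols("x a b")

# An expression with hard-coded numeric constants.
expr = 2.5 * sin(x) + 0.5

# Manually swap each constant for a named parameter symbol.
param_expr = expr.xreplace({Float(2.5): a, Float(0.5): b})

# The argument tuple fixes the calling order explicitly: f(x, a, b).
f = lambdify((x, a, b), param_expr)
print(param_expr, f(0.0, 2.5, 0.5))
```

Because the user writes the argument tuple, reordering or renaming parameters is an explicit, visible change rather than something inferred from the expression tree.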

Would it perhaps be possible to split the PR into two parts, first the Jax printer, and then a second PR for the parametric part?

Let's see what the others say.

@moorepants (Member) commented

I also don't understand the parametric part. If you internally replace all numeric nodes with symbols, how would the user know which values to pass to the resulting numeric function, and in what order, in a consistent way? For example, what if there are two 1's in the expression? Are you replacing those with the same symbol or with two different symbols? And if two, what if the user then provides different values for them? Maybe a description of the use case for the parametric option would help clear this up.
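
To make the duplicate-constant question concrete, here is a hypothetical automatic scheme that collects numeric atoms with `atoms(Number)`. It is not the approach used in #23661. Because `atoms` returns a set, both occurrences of `2` below map to the same parameter. The parameter order depends on a sort key chosen arbitrarily for this sketch.

```python
from sympy import Number, Symbol, symbols

x, y = symbols("x y")
expr = 2*x + 2*y + 3

# atoms() returns a set, so repeated constants collapse into a single entry.
constants = sorted(expr.atoms(Number), key=float)  # the sort key is an arbitrary choice

# Hypothetical naming scheme: p0, p1, ... in sorted order.
params = {c: Symbol(f"p{i}") for i, c in enumerate(constants)}
param_expr = expr.xreplace(params)
print(param_expr)
```

Here a single `p0` controls both coefficients, so a user cannot give them different values. Giving each occurrence its own symbol would instead need a per-occurrence tree walk. The user would still need a documented mapping from symbols to argument positions.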

@moorepants (Member) commented

> Would it perhaps be possible to split the PR into two parts, first the Jax printer, and then a second PR for the parametric part?

This sounds like a good idea.

@tttc3 changed the title from "Added JaxPrinter and Parametric printing" to "Added JaxPrinter ~~and Parametric printing~~" on Jun 17, 2022
@tttc3 changed the title from "Added JaxPrinter ~~and Parametric printing~~" to "Added JaxPrinter" on Jun 17, 2022
@github-actions bot commented Jun 17, 2022

Benchmark results from GitHub Actions

Lower numbers are good, higher numbers are bad. A ratio less than 1
means a speedup and greater than 1 means a slowdown. Green lines
beginning with + are slowdowns (the PR is slower than master or
master is slower than the previous release). Red lines beginning
with - are speedups.

Significantly changed benchmark results (PR vs master)

Significantly changed benchmark results (master vs previous release)

       before           after         ratio
     [77f1d79c]       [a28a214e]
     <sympy-1.10.1^0>                 
+       100±0.8ms          188±3ms     1.87  sum.TimeSum.time_doit

Full benchmark results can be found as artifacts in GitHub Actions
(click on checks at the top of the PR).

@oscargus added this to the SymPy 1.11 milestone on Jun 19, 2022
@oscarbenjamin (Collaborator) commented

Is the parametric part in a separate PR now?

Now that the parametric part has been removed from this PR, can someone review this?

It looks like a good addition for 1.11.

@oscarbenjamin (Collaborator) commented

Also there's a simple merge conflict.

@tttc3 mentioned this pull request on Jun 21, 2022
@tttc3 (Contributor, Author) commented Jun 21, 2022

> Is the parametric part in a separate PR now?
>
> Now that the parametric part has been removed from this PR, can someone review this?
>
> It looks like a good addition for 1.11.

The parametric pull request has been moved to #23661.

@bjodah (Member) left a comment

I made a second pass of reviewing. Some minor comments/questions.

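from sympy.core.symbol import symbols
from sympy.matrices.expressions import Identity, MatrixSymbol
from sympy.testing.pytest import raises
from sympy.utilities.lambdify import lambdify

n = symbols('n', integer=True)
N = MatrixSymbol("M", n, n)
raises(NotImplementedError, lambda: lambdify(N, N + Identity(n)))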
Member

Is JAX involved here?

@tttc3 (Contributor, Author)

This test has been copied over from the NumPyPrinter tests. I must have forgotten to change the lambdify expression to call the JaxPrinter instead. I will update the tests to rectify this.

@tttc3 (Contributor, Author)

Updated in commit 4b3869b.

@oscarbenjamin removed this from the SymPy 1.11 milestone on Jun 27, 2022
@oscarbenjamin (Collaborator) commented

I'm removing the 1.11 milestone. It would be good to get this into the 1.11 release, and if it is merged soon it will be. Otherwise there's no particular reason for this to be a release blocker.

@bjodah (Member) left a comment

This looks ready from my point of view.
Thank you @tttc3 for this contribution.

@bjodah merged commit dff24a1 into sympy:master on Jun 28, 2022
@patrick-kidger commented

Folks over here may be interested in sympy2jax, which also aims to do the SymPy->JAX conversion (and back again). I think the main differences from the approach here are (a) parameter handling and (b) there's no printing; it's just a normal JAX construct.
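
To illustrate the "no printing" distinction, here is a toy sketch that walks the expression tree and turns a SymPy expression into nested Python closures. `lambdify`, by contrast, generates source text and runs it with `exec`. The sketch uses `math` functions rather than JAX and is not sympy2jax's implementation. The `_OPS` table and `to_callable` helper are hypothetical names, and only a handful of node types are supported.

```python
import math
import operator
from functools import reduce

import sympy

# Hypothetical table mapping SymPy node types to plain-Python callables.
_OPS = {
    sympy.Add: lambda *xs: sum(xs),
    sympy.Mul: lambda *xs: reduce(operator.mul, xs, 1.0),
    sympy.Pow: operator.pow,
    sympy.sin: math.sin,
    sympy.cos: math.cos,
    sympy.exp: math.exp,
}


def to_callable(expr, args):
    """Build nested closures that evaluate ``expr``; no source code is generated."""
    if expr.is_Symbol:
        index = args.index(expr)  # position of this symbol in the call signature
        return lambda *vals: vals[index]
    if expr.is_number:
        value = float(expr)  # fold constants such as 2, 1/2 or pi up front
        return lambda *vals: value
    op = _OPS[expr.func]  # raises KeyError for node types the toy table lacks
    children = [to_callable(arg, args) for arg in expr.args]
    return lambda *vals: op(*(child(*vals) for child in children))


x, y = sympy.symbols("x y")
f = to_callable(sympy.sin(x) * y + x**2, (x, y))
print(f(0.5, 2.0))
```

One trade-off: the source that `lambdify` prints can be read and debugged directly, while a closure tree is harder to inspect.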

@asmeurer (Member) commented Jul 6, 2022

Thanks @patrick-kidger. Would it make sense to include sympy2jax in SymPy itself? Would you say it's strictly better than what's included here? I can imagine benefits to the printing approach (like being easier to debug), but I suspect the sympy2jax approach is better in most cases.

@patrick-kidger commented

IMO it is strictly better (I made it for a reason after all) but obviously I'm not an unbiased observer here.

I don't think it makes sense to bake it into SymPy itself. It makes an opinionated choice of library to build on: Equinox as opposed to Flax or Haiku. And longer-term, I don't think SymPy should need to implement a SymPy-to-Foo converter for every Foo, including those that have yet to be released. Delegating to an external library has worked so far for PyTorch.


8 participants