
ev-br
Member

@ev-br ev-br commented May 23, 2024

Towards gh-20678

What does this implement/fix?

Delegate signal.convolve and related functions to their cupyx.scipy.signal counterparts (e.g. cupyx.scipy.signal.convolve) if the inputs are CuPy arrays. For other array types, do the usual dance: convert to NumPy, run, and convert back.

Additional information

CuPy provides a near-complete clone of the scipy API in the cupyx.scipy namespace. We can treat these CuPy functions as accelerators: if a scipy function detects that its arguments are cupy-compatible, it can delegate all work to the cupyx function.

@ev-br ev-br requested review from larsoner and ilayn as code owners May 23, 2024 10:19
@github-actions github-actions bot added scipy.signal RFC Request for Comments; typically used to gather feedback for a substantial change proposal labels May 23, 2024
@lucascolley lucascolley changed the title RFC/POC: dispatch from scipy.submodule.func to cupyx.scipy.submodule.func RFC/POC: array types: dispatch to cupyx.scipy.submodule.func May 23, 2024
@lucascolley lucascolley added the array types Items related to array API support and input array validation (see gh-18286) label May 23, 2024
@lucascolley lucascolley changed the title RFC/POC: array types: dispatch to cupyx.scipy.submodule.func RFC/POC: array types, signal: dispatch to cupyx.scipy.submodule.func May 23, 2024
@ev-br
Member Author

ev-br commented May 23, 2024

If you want to try extending this support to other libraries which have a scipy namespace ... JAX in addition right now

Thanks for the suggestion @lucascolley! This is indeed a nice exercise to check how general the approach is. The last commit adds the dispatch. Curiously,

$ python dev.py test -t scipy/signal/tests/test_signaltools.py::TestConvolve -v -b jax.numpy -v
...
FAILED scipy/signal/tests/test_signaltools.py::TestConvolve::test_basic[jax.numpy] - AssertionError: dtypes do not match.
FAILED scipy/signal/tests/test_signaltools.py::TestConvolve::test_same[jax.numpy] - AssertionError: dtypes do not match.
FAILED scipy/signal/tests/test_signaltools.py::TestConvolve::test_same_eq[jax.numpy] - AssertionError: dtypes do not match.
FAILED scipy/signal/tests/test_signaltools.py::TestConvolve::test_2d_arrays[jax.numpy] - AssertionError: dtypes do not match.
FAILED scipy/signal/tests/test_signaltools.py::TestConvolve::test_valid_mode2[jax.numpy] - AssertionError: dtypes do not match.
FAILED scipy/signal/tests/test_signaltools.py::TestConvolve::test_same_mode[jax.numpy] - AssertionError: dtypes do not match.
FAILED scipy/signal/tests/test_signaltools.py::TestConvolve::test_convolve_method[jax.numpy] - AssertionError: 
FAILED scipy/signal/tests/test_signaltools.py::TestConvolve::test_convolve_method_large_input[jax.numpy] - AssertionError: dtypes do not match.
======================================================================================= 8 failed, 8 passed in 3.27s ============================

I am adding @skip_xp_backends for now; going forward, I guess it would be nice to have a matching xfail_xp_backends.

@lucascolley
Member

I assume it's float32 vs float64? Which one is the result from convolve?

@ev-br ev-br force-pushed the sigtools_convolve_cupy branch from e8d894b to 1443958 Compare May 23, 2024 14:55
@ev-br
Member Author

ev-br commented May 23, 2024

I assume it's float32 vs float64? Which one is the result from convolve?

There is at least one issue that looks like an f32/f64 tolerance problem, and there are several instances where jax returns float64 while cupyx and scipy preserve int64.

I'd be very wary of relaxing scipy tests on these. Either it's fixed upstream (if jax.scipy follows scipy), or we special-case it in tests (not in this PR I guess), or it just stays an xfail.
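For reference, here is a quick check of the SciPy side of this behaviour using only NumPy and SciPy: integer inputs give an integer result whose dtype is `np.result_type` of the inputs (typically int64). The JAX side is not shown.

```python
import numpy as np
from scipy import signal

a = np.array([1, 2, 3])
b = np.array([0, 1, 1])

out = signal.convolve(a, b)
# SciPy keeps the integer result type of the inputs instead of upcasting to float.
print(out, out.dtype)
```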

we should probably decide on what this variable should be called for consistency.

I'm not going to argue with whatever name you prefer. Mind pushing a tweak?

@lucascolley
Member

jax returns float64 while cupyx and scipy preserve int64.

As long as that is deliberate on scipy's side, it sounds like a bug in jax.

If you still believe that scipy_namespace_for is broken, let me know if we need changes in special!

@lucascolley lucascolley changed the title RFC/POC: array types, signal: dispatch to cupyx.scipy.submodule.func RFC/POC: array types, signal: dispatch to CuPy and JAX May 23, 2024
@lucascolley lucascolley removed request for ilayn and larsoner May 23, 2024 16:28
Member

@lucascolley lucascolley left a comment


For the CI failures, I suspect you need to restrict pytest.mark.usefixtures("skip_xp_backends") to functions/classes you have converted, rather than putting it in pytestmark, for now.

@lucascolley
Member

lucascolley commented May 23, 2024

We should get CI green for NumPy first by being explicit about what dtype we expect (or using check_dtype=False if we don't care), and then add a run of -t scipy/signal/tests/test_signaltools.py::TestConvolve to the Array API job.
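One way to be explicit about dtypes in a test is sketched below with a hypothetical `assert_close` helper built on `numpy.testing.assert_allclose`. Its `check_dtype` flag mirrors the option mentioned above, but the helper is not SciPy's own.

```python
import numpy as np
from numpy.testing import assert_allclose
from scipy import signal


def assert_close(actual, desired, *, check_dtype=True, rtol=1e-7):
    """Hypothetical test helper: compare values, and dtypes unless opted out."""
    if check_dtype:
        assert actual.dtype == desired.dtype, (
            f"dtypes do not match: {actual.dtype} != {desired.dtype}")
    assert_allclose(actual, desired, rtol=rtol)


def test_convolve_int_dtype():
    a = np.array([1, 2, 3])
    b = np.array([0, 1, 1])
    # Explicit expectation: integer inputs give an integer result.
    expected = np.array([0, 1, 3, 5, 3], dtype=np.result_type(a, b))
    assert_close(signal.convolve(a, b), expected)


def test_convolve_values_only():
    a = np.array([1, 2, 3])
    b = np.array([0, 1, 1])
    # This test only cares about values, so a float reference is fine.
    expected = np.array([0.0, 1.0, 3.0, 5.0, 3.0])
    assert_close(signal.convolve(a, b), expected, check_dtype=False)
```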

@ev-br ev-br requested a review from andyfaff as a code owner May 24, 2024 07:48
@lucascolley lucascolley removed the request for review from andyfaff May 24, 2024 07:51
@ev-br ev-br force-pushed the sigtools_convolve_cupy branch 2 times, most recently from 403ed97 to b9e4f12 Compare May 25, 2024 15:57
@ev-br
Member Author

ev-br commented May 25, 2024

OK, reworked to mimic scipy.special: all delegation logic sits in a _support_alternative_backends.py file, and our original _signaltools.py is intact, as requested in #20772 (comment)

I have to admit I don't much like the sys.modules[func.__name__] = func dance (copied from scipy.special), but if it works, it works, I guess.
Overall, this version does look a bit cleaner indeed.

One other simplification concerns who is responsible when an accelerator implements only part of the scipy API. For instance, jax.scipy.signal.convolve2d only supports fillvalue=0. So who should handle the case where a user passes jax arrays and a nonzero fillvalue to scipy.signal.convolve2d? Previously, this PR tried to intercept this case and raise a ValueError, so that the user got a traceback pointing to scipy. The last commit removes this logic and makes {jax,cupy} arrays fall straight through to {jax.scipy,cupyx.scipy}.
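For comparison, the interception logic that the last commit removes could look roughly like the sketch below. The `_UNSUPPORTED` table and `check_supported` function are hypothetical. The only limitation they encode is the `fillvalue=0` restriction of `jax.scipy.signal.convolve2d` mentioned above.

```python
# Hypothetical table of known partial-API limitations of accelerator backends:
# backend name -> {function name -> predicate that flags unsupported kwargs}.
_UNSUPPORTED = {
    "jax": {
        # jax.scipy.signal.convolve2d only supports fillvalue=0 (per the PR discussion)
        "convolve2d": lambda kwargs: kwargs.get("fillvalue", 0) != 0,
    },
}


def check_supported(backend, func_name, **kwargs):
    """Raise ValueError on the SciPy side before delegating unsupported arguments."""
    is_unsupported = _UNSUPPORTED.get(backend, {}).get(func_name)
    if is_unsupported is not None and is_unsupported(kwargs):
        raise ValueError(
            f"{func_name}: {kwargs} is not supported for {backend} arrays")
```

The trade-off is maintenance. Such a table has to track each accelerator's limitations as they change. Letting arrays fall through to `jax.scipy`/`cupyx.scipy` leaves error reporting to those libraries.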

Member

@lucascolley lucascolley left a comment


@mdhaber could you take a look at _support_alternative_backends.py as the author of the equivalent in special?

@ev-br ev-br force-pushed the sigtools_convolve_cupy branch 2 times, most recently from 1d17a32 to 6d825ee Compare November 3, 2024 18:36
@ev-br ev-br force-pushed the sigtools_convolve_cupy branch from 6d825ee to 2161976 Compare December 6, 2024 09:01
Member

@j-bowhay j-bowhay left a comment


I had a look over this, spotted a few things, and probably missed a bunch more! Maybe we should just merge this, since there's a decent period before the next release to iron out any issues.

@lucascolley
Member

Maybe we should just merge this

my term has just finished so I should have time to take a look at this

Member

@lucascolley lucascolley left a comment


approving the production code changes! (modulo a few observations)

I'll take a look at the test changes once the existing comments have been addressed

@lucascolley lucascolley added this to the 1.16.0 milestone Dec 8, 2024
@lucascolley lucascolley added the enhancement A new feature or improvement label Dec 8, 2024
@ev-br ev-br force-pushed the sigtools_convolve_cupy branch from 159a28a to 41d4c3c Compare December 9, 2024 07:27
@ev-br ev-br force-pushed the sigtools_convolve_cupy branch from 41d4c3c to 6d6543f Compare December 9, 2024 08:07
@ev-br
Member Author

ev-br commented Dec 9, 2024

Addressed review comments.

Maybe we should just merge this, since there's a decent period before the next release to iron out any issues.

Certainly would be happy to :-).

NB smoke-docs is going to fail left and right because of scipy/scipy_doctest#175, here and elsewhere.

Member

@lucascolley lucascolley left a comment


overall LGTM, thanks Evgeni & reviewers! Just a few comments

ev-br and others added 2 commits December 9, 2024 18:34
@ev-br
Member Author

ev-br commented Dec 9, 2024

CI is green modulo two known and unrelated issues (mypy and smoke-docs), there are two approvals, merging as approved and CI-green. Thanks to all reviewers!

@ev-br ev-br merged commit 2c065e1 into scipy:main Dec 9, 2024
34 of 37 checks passed
@j-bowhay
Member

j-bowhay commented Dec 9, 2024

Which of the PRs stacked on this one is the "next"?

@ev-br
Member Author

ev-br commented Dec 9, 2024

#21713 is the closest "next", I'd say.
#21783 is largely independent.

Once these two are in (or on their way in), I will be able to move on to filters which accept windows, filter design, and all that jazz.

Labels
array types Items related to array API support and input array validation (see gh-18286) enhancement A new feature or improvement scipy.signal

6 participants