ENH: ndimage: delegate to CuPy #21091
Force-pushed from 21290a1 to eaa9a0a.
```diff
@@ -107,25 +129,25 @@ def _validate_complex(self, array, kernel, type2, mode='reflect', cval=0):
         # test correlate output dtype
         output = correlate(array, kernel, output=type2)
         assert_array_almost_equal(expected, output)
-        assert_equal(output.dtype.type, type2)
+        assert output.dtype.type == type2
```
IIUC, the dtypes passed in here will always be NumPy dtypes, but this takes advantage of the fact that CuPy dtypes can be compared with NumPy dtypes, and this PR is only for the CuPy backend?
Yeah. CuPy dtypes in fact are numpy dtypes, so this works for cupy. Won't work for dtypes from array-api-strict though, so if we want to make it work in this PR, the diff is going to grow by another ~500 mechanical changes :-).
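`dtype.type` is a NumPy attribute, which is why the `output.dtype.type == type2` check relies on CuPy dtypes being NumPy dtypes. A minimal sketch of a check that might carry over to other backends, assuming they follow the array API standard (which requires dtype objects to support `==`), compares `x.dtype` directly with a dtype object taken from the same namespace. `assert_dtype` is a hypothetical helper, and NumPy stands in for `xp`.

```python
import numpy as np


def assert_dtype(x, expected_dtype):
    """Hypothetical helper: check an array's dtype without NumPy-only attributes.

    ``x.dtype.type`` exists for NumPy (and CuPy, whose dtypes are NumPy dtypes),
    but not for every array API backend. ``x.dtype == expected_dtype`` relies
    only on dtype equality, which the array API standard requires.
    """
    assert x.dtype == expected_dtype, f"got {x.dtype}, expected {expected_dtype}"


# NumPy stands in for ``xp`` here; with another backend, take the expected
# dtype from that backend's namespace (e.g. ``xp.float32``).
xp = np
out = xp.zeros(4, dtype=xp.float32)
assert_dtype(out, xp.float32)
```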
```diff
         # test correlate with pre-allocated output
-        output = np.zeros_like(array, dtype=type2)
+        output = xp.zeros_like(array, dtype=type2)
         correlate(array, kernel, output=output)
         assert_array_almost_equal(expected, output)
```
Does `assert_array_almost_equal` work for CuPy arrays?
No. This is a shim from _lib.
Speaking of dtypes: what is the recommended way to declare a dtype across all array API compatible backends? The pattern is
and it fails to understand the dtype under …
Force-pushed from 07c51a0 to fd28566.
(However, since we are dealing with the raw namespaces rather than the array-api-compat wrapped versions initially in tests, this is what is blocking us from adding dask, as there is no …
In the test itself, yes, but how to …
Oops 🤦♂️ try …
No dice:

```diff
diff --git a/scipy/ndimage/tests/test_interpolation.py b/scipy/ndimage/tests/test_interpolation.py
index db3479dbb1..63e816d7d8 100644
--- a/scipy/ndimage/tests/test_interpolation.py
+++ b/scipy/ndimage/tests/test_interpolation.py
@@ -194,11 +194,12 @@ class TestNdimageInterpolation:
         assert_array_almost_equal(out, [0, 4, 1, 3])
     @pytest.mark.parametrize('order', range(0, 6))
-    @pytest.mark.parametrize('dtype', [np.float64, np.complex128])
+    @pytest.mark.parametrize('dtype', ["float64", "complex128"])
     def test_geometric_transform05(self, order, dtype, xp):
         if is_cupy(xp):
             pytest.xfail("CuPy does not have geometric_transform")
+        dtype = getattr(xp, "dtype")
         data = xp.asarray([[1, 1, 1, 1],
                            [1, 1, 1, 1],
                            [1, 1, 1, 1]], dtype=dtype)
```

produces …
I think there is a typo in your patch; should it not be …
Ah, that's it. Thanks Jake!
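A minimal sketch of the intended pattern, with NumPy standing in for the `xp` fixture: parametrize over dtype names as strings and resolve each one against the namespace under test with `getattr(xp, name)`. In the patch above, `getattr(xp, "dtype")` looks up the attribute literally called `dtype` (for NumPy, the `np.dtype` class) instead of the parametrized name. The sketch assumes the namespace exposes dtypes under their standard names (`float64`, `complex128`), as NumPy and the array API standard do; `make_data` is a hypothetical helper.

```python
import numpy as np


def make_data(xp, dtype_name):
    """Hypothetical helper: build test input in the namespace ``xp``.

    ``dtype_name`` is a plain string (e.g. from ``pytest.mark.parametrize``),
    resolved against the namespace under test, so the same test body can
    serve NumPy, CuPy, torch, array-api-strict, ...
    """
    dtype = getattr(xp, dtype_name)  # note: the variable, not the string "dtype"
    return xp.asarray([[1, 1, 1, 1],
                       [1, 1, 1, 1],
                       [1, 1, 1, 1]], dtype=dtype)


# NumPy stands in for the ``xp`` fixture here.
for name in ["float64", "complex128"]:
    data = make_data(np, name)
    print(name, data.dtype, data.shape)
```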
Force-pushed from 0761236 to 9b013f4.
I've finished one of the big test files. There are a lot of places where the indentation was thrown off by adding `xp.asarray`. A few comments otherwise.
Review comments addressed.
Force-pushed from 4fd7284 to 2a6944f.
OK, hopefully I've reformatted all instances of misaligned arrays.
getting there! this is a huge diff 😅
This is all going to be mechanical, they said :-).
Force-pushed from 1bae880 to 021addf.
All right, all the test formatting etc. should be in shape now.
Force-pushed from 44406fd to 705d9c9.
The last round of comments is addressed.
Force-pushed from 05c5e55 to 37655f9.
FYI @lucascolley: turning on other backends requires further test tweaks. I'll sync the tests here when #21150 reaches a natural sync point.
Force-pushed from 37655f9 to f4e5d46.
All tests are in sync now. Sorry for the back-and-forth, @lucascolley.
Could you take a look at the CI failures here? I assume most of it is adding the appropriate skips, to be removed in the next PR.
Ha, this one is a funny one. TL;DR: Consider
The decorator clears the other skips, so the test is only skipped on jax, and fails miserably on torch and array-api-strict. <stepping away from the keyboard for 1/2 hour> It seems to me that the attempt to break a large PR into a "stack" of this one and #21150, which works great on some other projects, simply does not work in SciPy, or at least not here. I'm open to suggestions on how to move this forward, preferably in a way that won't require me to add and immediately remove ~100 skips.
Would you like to just consolidate into one PR? The diff is so big anyway that making it a bit bigger doesn't really make it any harder to review. Re multi-level skips: yes, this has caught me out before. I haven't rushed to work around it, though, as I think the mental model of 'the source of truth for skips is the closest decorator to the function's level, if one exists' is okay to grasp once you've been told about it. It just isn't the most intuitive model.
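To make the two possible rules concrete, here is a toy model in plain Python. It does not use pytest or SciPy's conftest, whose actual marker handling is not shown here. Under "closest wins", a function-level skip list replaces the class-level one, which matches the behaviour described above; under "union", skips from both levels accumulate. The function names and the backend lists are illustrative assumptions.

```python
def closest_wins(class_skips, func_skips):
    """Toy rule: a function-level skip list, if present, replaces the class-level one."""
    if func_skips is not None:
        return set(func_skips)
    return set(class_skips or ())


def union(class_skips, func_skips):
    """Toy alternative: skips from every level accumulate."""
    return set(class_skips or ()) | set(func_skips or ())


# Illustrative setup: two backends skipped for a whole class, one more on a single test.
class_level = ["torch", "array_api_strict"]
func_level = ["jax.numpy"]

print(sorted(closest_wins(class_level, func_level)))  # → ['jax.numpy']
print(sorted(union(class_level, func_level)))  # → ['array_api_strict', 'jax.numpy', 'torch']
```

The "union" rule avoids silently un-skipping backends, at the cost of making it harder to narrow a class-level skip for a single test.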
OK, let's continue in #21150 indeed. The bulk of the diff is just the same:
The main footgun IME is that repeated skips tend to fall through the cracks: if you skip X in …
Reference issue
cross-ref #21012, #20772
What does this implement/fix?
This is a POC of the layer of dispatch/delegation from `scipy.ndimage` to `cupyx.scipy.ndimage`, for a full SciPy submodule. `ndimage` is probably clearer than `signal` in that it's all compiled code.

Several things to note:

- … `xp_assert` framework, add xfails for CuPy issues
- `_support_alternative_backends.py` is for now copy-pasted from ENH: array types, signal: delegate to CuPy and JAX for correlations and convolutions #20772; the next step would be to deduplicate with the machinery from ENH: `linalg`/`signal`/`special`: unify approach for array API dispatching #21012; I think some simplifications are possible, too.
- … `linalg`/`signal`/`special`: unify approach for array API dispatching #21012 is that instead of a mapping dictionary here we use a dedicated private module, `_dispatchers.py`.
- `_ndimage_api.py` collects imports from implementer modules instead of `__init__.py`
- `_support_alternative_backends.py` decorates things from `_ndimage_api` with the CuPy dispatch
- `__init__.py` imports decorated names from `_support_alternative_backends.py`
- … `_dispatcher` functions. The full set is worked out for the whole submodule, so we can assess whether we can simplify and possibly avoid dispatching on a full signature (I don't think we can, but am open to pleasant surprises).

Technical TODOs:

- `map_coordinates`?
- `dev.py -b all`
Additional information