TST: fix GPU failures #21294
@mdhaber you wrote in #20957 (comment) that you think some modifications to the dev docs would be useful here. Could you clarify what you would like?
and to add the It would be nice to update this with either:
Okay, I'll add an example using Of course, this can change if/when we start using
We discussed this in the community meeting and we should be able to look into GPU CI soon.
@@ -178,7 +179,8 @@ def test_correlate01(self, xp):
         output = ndimage.convolve1d(array, weights)
         assert_array_almost_equal(output, expected)

-    @skip_xp_backends("jax.numpy", reasons=["output array is read-only."])
+    @skip_xp_backends("jax.numpy", reasons=["output array is read-only."],
+                      cpu_only=True, exceptions=['cupy', 'jax.numpy'])
How can `jax.numpy` be both the first argument (the backend to skip) and in the `exceptions` list?
Being the first argument means "always skip due to a reason in `reasons`". Being in `exceptions` means "don't get skipped by `cpu_only` when not on CPU as delegation is implemented".
In this case, I just went with the blanket "delegation is implemented for CuPy and JAX in this module", even though the support for JAX is patchy. In its current state, JAX could be removed from `exceptions` here. But if (somehow) JAX was no longer skipped via the first argument, we would want to add it back to `exceptions` to avoid the `cpu_only` skip.
Evidently this API is not perfectly intuitive, but I haven't thought of anything better yet. Please let me know if you do!
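To make the two rules above concrete, here is a minimal sketch of the skip decision as described in this comment. It is not SciPy's implementation: `should_skip` and its `on_cpu` and `skip` parameters are hypothetical names used only for illustration, and only `cpu_only` and `exceptions` mirror the decorator arguments shown in the diff.

```python
def should_skip(backend, on_cpu, skip=(), cpu_only=False, exceptions=()):
    """Hypothetical model of the skip rules described above (not SciPy's code)."""
    # Rule 1: a backend passed as the first argument is always skipped,
    # for one of the stated reasons.
    if backend in skip:
        return True
    # Rule 2: cpu_only skips any run that is not on CPU, unless the backend
    # is listed in exceptions (delegation is implemented for it).
    if cpu_only and not on_cpu and backend not in exceptions:
        return True
    return False


# Arguments mirroring the decorator in the diff.
skip = ["jax.numpy"]
exceptions = ["cupy", "jax.numpy"]

for backend, on_cpu in [("jax.numpy", False), ("cupy", False),
                        ("torch", False), ("torch", True)]:
    result = should_skip(backend, on_cpu, skip=skip,
                         cpu_only=True, exceptions=exceptions)
    print(f"{backend:10s} on_cpu={on_cpu!s:5s} skipped={result}")
```

Under this model, JAX is skipped by the first rule regardless of `exceptions`, which is why it could currently be dropped from that list. If it were ever removed from the first argument, keeping it in `exceptions` is what would stop the `cpu_only` rule from skipping it on GPU.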
Okay, that explanation seems reasonable. I can't think of something much cleaner. Explicitly explaining this case in the `skip_xp_backends` docstring would be useful, to avoid future confusion to the extent possible.
I added a couple of comments/questions, but none are blocking. Overall this looks great, so +1 to merge soon from me.
Thanks Ralf, I'll give the docs an update then this is ready. Unless the test failures are annoying anyone locally, in which case feel free to merge immediately.
b3e5329 to 45fb169
merged - clean-up and docs update inbound
FYI I did put in that credit card request after the meeting and just received a reply on it - should arrive soon 🤞🏼
Reference issue
Closes gh-21227, closes gh-21292, closes gh-20957.
What does this implement/fix?
This PR makes `SCIPY_DEVICE=cuda python dev.py test -b all` pass for me locally, in an environment with CuPy, PyTorch and JAX.
Additional information
Please take a look at any parts you wrote and let me know if the modifications look okay!
This diff is longer than it would be if we figured out how to "stack" `skip_xp_backends` decorators. But I haven't given that any thought recently, and it seems difficult.