ENH: linalg: support array API for standard extension functions #19260
Conversation
I think the regular function modifications are OK. But I'm not sure we really need all the test changes. They don't reflect the goals of the tests in certain places, and I'm not sure I understand why we test non-NumPy xp namespaces in the tests. It should be pretty safe to test SciPy with default NumPy arrays, without any xp_assert_close calls or modified tol parameters.
This looks pretty good to me overall. I'd like to see the diff shrink as much as possible though, both in the implementations and tests. That's in general a sign that the code is in good shape, and it makes things easier to review and understand later on.
> I think the regular function modifications are OK. But I'm not sure we really need all the test changes. They don't reflect the goals of the tests in certain places, and I'm not sure I understand why we test non-NumPy xp namespaces in the tests. It should be pretty safe to test SciPy with default NumPy arrays, without any xp_assert_close calls or modified tol parameters.
I think it is quite useful to test non-NumPy arrays; without testing, it's almost certainly going to be broken. I think of these as testing optional dependencies - just like we have tests with mpmath, scikit-umfpack and a whole bunch of other optional runtime dependencies.
Sometimes this requires extra code; however, it also tends to uncover bugs and non-standard code constructs that (when refactored) improve the code itself. Two examples here:
- The changes here from integer to floating-point arrays for testing make sense. Functions like solve are inherently floating-point only. We also need to test that integers (and lists, and other array-likes) are still converted correctly and that we don't break backwards compatibility; however, that can be a single small test. That so many tests use integers is a matter of previous authors taking a shortcut because it didn't matter much, rather than all those tests using integers by design.
- The stricter dtype checks led me to spot a bug quickly when Lucas asked me about result_type:
>>> import numpy as np
>>> from scipy import linalg
>>> a = np.array([[3, 2, 0], [1, -1, 0], [0, 5, 1]])
>>> b = np.array([2, 4, -1])
>>> x = linalg.solve(a, b)
>>> a.dtype, b.dtype
(dtype('int64'), dtype('int64'))
>>> x.dtype
dtype('float64')
>>> # so the output dtype should be float64 for integer input, but:
>>> linalg.solve(a, np.empty((3, 0), dtype=np.int64)).dtype
dtype('int64')
We've seen in many places in cluster and fft as well that our tests don't check for expected dtypes, and we often have inconsistent return dtypes as a result (arguably all bugs).
So there is extra value from testing with other libraries: finding bugs and improving test coverage.
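To illustrate, here is a minimal sketch of the kind of regression test the stricter dtype checks enable; the test name and structure are hypothetical and not part of this PR:

import numpy as np
from scipy import linalg

def test_solve_integer_input_dtype():
    # integer inputs should be promoted to a floating-point result dtype
    a = np.array([[3, 2, 0], [1, -1, 0], [0, 5, 1]])
    b = np.array([2, 4, -1])
    assert linalg.solve(a, b).dtype == np.float64
    # the empty right-hand side should promote the same way,
    # not stay int64 (the bug shown above)
    assert linalg.solve(a, np.empty((3, 0), dtype=np.int64)).dtype == np.float64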
Testing is always nice indeed, but the question is what to do when it breaks. I think none of us want to go chasing around the PyTorch or CuPy repos, fixing things that aren't really ours to fix, just to get our tests green.
I think the same as when something breaks in NumPy, Cython, pytest, Sphinx or wherever else: we file an issue and skip the test or put a temporary upper bound. We're using pretty core/standard functions here, so I am not too worried about seeing many regressions once things work - that would be really surprising. And in terms of debugging or even contributing upstream, I'd much rather work with CuPy or PyTorch than with things like pytest/Sphinx/mpmath. Also, the CuPy and PyTorch teams (and Dask and JAX) have invested large amounts of effort in NumPy and SciPy compatibility, so I'm pretty sure they'd appreciate bug reports and be willing to address them.
xp_assert_close(u.T @ u, xp.eye(3), atol=1e-6)
xp_assert_close(vh.T @ vh, xp.eye(3), atol=1e-6)
sigma = xp.zeros((u.shape[0], vh.shape[0]), dtype=s.dtype)
for i in range(s.shape[0]):
    sigma[i, i] = s[i]
assert_array_almost_equal(u @ sigma @ vh, a)
xp_assert_close(u @ sigma @ vh, a, rtol=1e-6)
the tolerances around here could do with a look; I'm not sure why I wrote a mixture of atol and rtol.
ah, I remember: it's because the default for np.testing.assert_array_almost_equal is roughly equivalent to rtol=0, atol=1.5e-6.
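To make that equivalence concrete, a rough illustration in plain NumPy (assert_array_almost_equal's default decimal=6 checks abs(actual - desired) < 1.5e-6 elementwise):

import numpy as np

x = np.array([1.0, 2.0, 3.0])
y = x + 1e-7  # difference well below 1.5e-6

# default decimal=6: passes when abs(x - y) < 1.5 * 10**-6
np.testing.assert_array_almost_equal(x, y)

# roughly the same check expressed with explicit tolerances
np.testing.assert_allclose(x, y, rtol=0, atol=1.5e-6)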
CI should be green. I think this is almost ready. A few questions remain:
EDIT: spoke too soon on CI, but it looks like just a …
EDIT 2: finally green :)
Just thought I'd come try to support this in case it would help @lucascolley. I'm also happy to stay out of it because it is easy for the kitchen to get too crowded.
If you'd like me to continue, one other question beforehand: were you interested in seeing if a mechanism like special's support_alternative_backends would work instead of modifying the individual functions? The code is mostly boilerplate, so I think it can be abstracted out. I'll be looking at whether the special approach can be used for signal later this week, and I could also throw this into the mix. The thought is that it could eliminate most of the diff in the functions themselves.
unsupported_args = {
    'lower': lower,
    'assume_a': assume_a != 'gen',
    'transposed': transposed
Are we silently ignoring overwrite_a and overwrite_b since they don't guarantee anything about the behavior? Just checking understanding.
Yep.
@@ -141,11 +145,37 @@ def solve(a, b, lower=False, overwrite_a=False,
    array([ True,  True,  True], dtype=bool)

    """
    xp = array_namespace(a, b)
    if check_finite:
I suppose the idea here is that it's easy enough for us to support check_finite even though it's not in the standard?
Not that we need to, but it seems like it would be just as easy to support transposed. With slightly more effort, we could support lower with other options for assume_a. (It would not necessarily lead to a more efficient solution; just a more consistent interface.)
Just want to make sure I'm understanding correctly - not suggesting any action.
Yes, and yes. I was just implementing what was 'free' in this PR (just using a kwarg of _asarray here) and leaving the rest in xp_unsupported_args. Ideally, that list would be reduced as much as possible, which can happen in two ways:
- us implementing agnostic versions of the args (which may be easy for some of them - see the sketch after this list)
- args being added to the standard API
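As a rough illustration of the first option, transposed could plausibly be handled agnostically before dispatching; this is a sketch, not code from the PR, and it ignores the complex case (where scipy.linalg.solve currently raises for transposed=True):

# hypothetical: support transposed=True instead of listing it as unsupported
if transposed:
    # solve a.T @ x = b by transposing the last two axes of a up front;
    # matrix_transpose is in the standard's main namespace
    a = xp.matrix_transpose(a)
if hasattr(xp, 'linalg'):
    return xp.linalg.solve(a, b)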
}
if any(unsupported_args.values()):
    xp_unsupported_args(unsupported_args)
if hasattr(xp, 'linalg'):
Is it possible for xp to have linalg but for linalg not to have solve?
Yes, but we don't want to support such a library. Ref: https://data-apis.org/array-api/latest/extensions/index.html
> Extension module implementors must aim to provide all functions and other public objects in an extension. The rationale for this is that downstream usage can then check whether or not the extension is present (using hasattr(xp, 'extension_name') should be enough), and can then assume that functions are implemented. This in turn makes it also easy for array-consuming libraries to document which array libraries they support - e.g., “all libraries implementing the array API standard and its linear algebra extension”.
dtype = xp.result_type(a, b, xp.float32)
a = xp.astype(a, dtype)
b = xp.astype(b, dtype)
You might want to do this dtype work even if you have to fall back to NumPy. The reason is that if a package doesn't have linalg.solve and then adds it, the behavior could (conceivably) change if it starts doing dtype conversions before the arrays are passed into the function.
That seems reasonable; I can't remember whether there was any reason for including this only behind this guard.
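A minimal sketch of what lifting the promotion out of the guard could look like, reusing the helpers already in this diff; whether the NumPy fallback should really share it is exactly the question above:

xp = array_namespace(a, b)
# promote to a floating-point result dtype regardless of which branch runs,
# so the output dtype doesn't change if xp later gains a linalg extension
dtype = xp.result_type(a, b, xp.float32)
a = xp.astype(a, dtype)
b = xp.astype(b, dtype)
if hasattr(xp, 'linalg'):
    return xp.linalg.solve(a, b)
# otherwise continue into the existing NumPy/LAPACK path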
Sure, feel free to have a go and we can compare notes. That should be a useful exercise. Note that this is a rather unusual PR, deliberately including only boilerplate such that we can make use of the extension module. Most of these …
Do you mean only some parts of …? If so, …
I'm more so talking about functions which can be partially converted. I suppose we need to figure out the long-term worth of such partial conversions. If we can make a function agnostic apart from one LAPACK call, should we? Or should we just resign ourselves to converting to NumPy for the entire function? Sure, converting to NumPy immediately is fine for now, but will we want it to stay that way as this stuff leaves the experimental stage? I'm not sure of the answers to any of these questions, but the guidance from Pamphile's original PR was pretty clear that we only go to … The answer may well vary by submodule - if some code is destined to always need compiled code, for example. There was some discussion along these lines above somewhere. EDIT: apologies for the double semantics of "conversion"
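For concreteness, a rough sketch of the "partial conversion" pattern being discussed - hypothetical code, not from this PR: keep pre- and post-processing array-agnostic and drop to NumPy only around the single compiled (LAPACK-backed) call.

import numpy as np
from scipy import linalg
from scipy._lib._array_api import array_namespace

def partially_converted_solve(a, b):
    xp = array_namespace(a, b)
    # agnostic pre-processing stays in the xp world
    dtype = xp.result_type(a, b, xp.float32)
    a = xp.astype(a, dtype)
    b = xp.astype(b, dtype)
    # convert to NumPy only for the compiled LAPACK call ...
    x = linalg.solve(np.asarray(a), np.asarray(b))
    # ... then convert back so the caller gets their own array type
    return xp.asarray(x)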
Personally, I think it's a matter of balance between performance and simplicity. For … In …
Yeah, I'd second that. I hope we can also clarify the guidance from gh-18286 (#18286 (comment)):
This suggests that device transfers are inherently bad, and that inherent badness would even outweigh everything working plus a performance gain. So avoiding device transfers was a good part of my motivation for gh-20549, but that was not met enthusiastically. So under what conditions are device transfers OK, or what are the considerations besides performance?
Yes, I've thought maybe I should either refer to these array API rewrites as "translations" or to changing to NumPy arrays as "coercions". But I have not really decided for myself, and I am happy with the double meaning of "conversion" in the meantime.
Thanks Matt, agreed with your most recent comment. The goal is to abstract as much of the boilerplate away as possible (e.g. scipy/scipy/fft/_basic_backend.py, lines 24 to 55 in c22b657). If that can be done with a _support_alternative_backends.py-esque method, then that sounds better. I'll put this PR on hold while you work on that, if you would still like to.
Closing: this is still discoverable in the linked issue, but it is best taken forward in smaller chunks.
For context, this work was started as part of my internship at Quansight Labs, which ran until the end of September 2023.
Reference issue
Towards gh-19068 and gh-18867.
Please see gh-19068 for context.
What does this implement/fix?
Support is added for the functions in the array API standard linalg extension. This allows users to input arrays from any compatible array library.
Tests are modified to allow testing with numpy.array_api, cupy and torch. Some new tests are added for exceptions for unsupported parameters.
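For illustration (not part of the original description), a minimal usage sketch, assuming the experimental support is enabled via the SCIPY_ARRAY_API=1 environment variable and PyTorch is installed:

# run with SCIPY_ARRAY_API=1 set in the environment
import torch
from scipy import linalg

a = torch.tensor([[3., 2., 0.], [1., -1., 0.], [0., 5., 1.]])
b = torch.tensor([2., 4., -1.])

x = linalg.solve(a, b)  # dispatches to torch.linalg.solve via the xp namespace
print(type(x))          # expected: <class 'torch.Tensor'>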
Additional information
I was not able to convert every relevant test since lots of NumPy-specific things are used. We may want to find ways to convert some of these, or write new tests to serve the same purpose.
TestSVD is a little strange due to the way the lapack_driver parameter is tested. I have tried to apply a minimal refactor here, but a more substantial refactor may result in something more readable. It is a bit of a misnomer that all of our array API compatible tests are under TestSVD_GESDD, just because gesdd is the default value.
Lots of tests are failing for PyTorch CUDA, but hopefully we just need to wait for pytorch/pytorch#106773 to come through.