
Conversation

@lucascolley (Member) commented Sep 19, 2023

For context, this work was started as part of my internship at Quansight Labs, which ran until the end of September 2023.

Reference issue

Towards gh-19068 and gh-18867.
Please see gh-19068 for context.

What does this implement/fix?

Support is added for the functions in the array API standard linalg extension. This allows users to input arrays from any compatible array library.

Tests are modified to allow testing with numpy.array_api, cupy and torch. Some new tests are added for exceptions for unsupported parameters.

Additional information

I was not able to convert every relevant test, since lots of NumPy-specific things are used. We may want to find ways to convert some of these, or write new tests to serve the same purpose.

TestSVD is a little strange due to the way the lapack_driver parameter is tested. I have tried to apply a minimal refactor here, but a more substantial refactor may result in something more readable. It is a bit of a misnomer that all of our array-API-compatible tests live under TestSVD_GESDD just because gesdd is the default driver.

Lots of tests are failing for PyTorch CUDA, but hopefully we just need to wait for pytorch/pytorch#106773 to come through.

@lucascolley marked this pull request as ready for review September 19, 2023 13:17
@lucascolley changed the title from "WIP, ENH: linalg: support array API for standard extension functions" to "ENH: linalg: support array API for standard extension functions" Sep 19, 2023
@j-bowhay added the enhancement, scipy.linalg, and array types labels Sep 19, 2023
@ilayn (Member) left a comment

I think the regular function modifications look OK to me. But I'm not sure we really need all the test changes. They don't reflect the goals of the tests in certain places, and I'm not sure I understand why we test non-NumPy xp namespaces in the tests. It should be pretty safe to test SciPy with default NumPy arrays, without any xp_assert_close's or modified tol parameters.

@rgommers (Member) left a comment

This looks pretty good to me overall. I'd like to see the diff shrink as much as possible though, both in the implementations and tests. That's in general a sign that the code is in good shape, and it makes things easier to review and understand later on.

> I think the regular function modifications look OK to me. But I'm not sure we really need all the test changes. They don't reflect the goals of the tests in certain places, and I'm not sure I understand why we test non-NumPy xp namespaces in the tests. It should be pretty safe to test SciPy with default NumPy arrays, without any xp_assert_close's or modified tol parameters.

I think it is quite useful to test non-numpy arrays; without testing it's almost certainly going to be broken. I think of these as testing optional dependencies - just like we have tests with mpmath, scikit-umfpack and a whole bunch of other optional runtime dependencies.

Sometimes this requires extra code; however, it also tends to uncover bugs and non-standard code constructs that (when refactored) improve the code itself. Two examples here:

  1. The changes here from integer to floating point arrays for testing make sense. Functions like solve are inherently floating-point-only. We also need to test that integers (and lists, and other array-likes) are still converted correctly and that we don't break backwards compatibility. However, that can be a single small test (a sketch of such a test follows at the end of this comment). That many tests use integers is a matter of previous authors taking a shortcut because it didn't matter much, rather than all those tests using integers by design.
  2. The stricter dtype checks led me to spot a bug quickly when Lucas asked me about result_type:
>>> import numpy as np
>>> from scipy import linalg
>>> a = np.array([[3, 2, 0], [1, -1, 0], [0, 5, 1]])
>>> b = np.array([2, 4, -1])
>>> x = linalg.solve(a, b)
>>> a.dtype, b.dtype
(dtype('int64'), dtype('int64'))
>>> x.dtype
dtype('float64')

>>> # so the output dtype should be float64 for integer input, but:
>>> linalg.solve(a, np.empty((3, 0), dtype=np.int64)).dtype
dtype('int64')
We've seen in many places in cluster and fft as well that our tests don't check for expected dtypes, and we often have inconsistent return dtypes as a result (arguably all bugs).

So there is extra value from testing with other libraries: finding bugs and improving test coverage.
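Purely as an illustration of the "single small test" point above, here is a sketch; the test name and exact assertions are hypothetical, and the last assertion encodes the expected behaviour, so it would fail until the dtype bug shown above is fixed:

import numpy as np
from scipy import linalg

def test_solve_integer_and_list_inputs():
    # Integer (and list) inputs should still be accepted and promoted to a
    # floating-point result dtype, preserving backwards compatibility.
    a = np.array([[3, 2, 0], [1, -1, 0], [0, 5, 1]])
    b = np.array([2, 4, -1])
    assert linalg.solve(a, b).dtype == np.float64
    assert linalg.solve(a.tolist(), b.tolist()).dtype == np.float64
    # An empty right-hand side should follow the same promotion rule;
    # per the snippet above, it currently returns int64 instead.
    assert linalg.solve(a, np.empty((3, 0), dtype=np.int64)).dtype == np.float64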

@ilayn (Member) commented Sep 27, 2023

> I think it is quite useful to test non-numpy arrays; without testing it's almost certainly going to be broken. I think of these as testing optional dependencies - just like we have tests with mpmath, scikit-umfpack and a whole bunch of other optional runtime dependencies.

Testing is always nice indeed, but the question is what to do when it is broken. I think none of us wants to go chasing around the PyTorch or CuPy repos, fixing things that aren't really ours to fix, just to get our tests green.

@rgommers (Member) commented

> Testing is always nice indeed, but the question is what to do when it is broken. I think none of us wants to go chasing around the PyTorch or CuPy repos, fixing things that aren't really ours to fix, just to get our tests green.

I think the same if something breaks in NumPy, Cython, pytest, Sphinx or wherever else: we file an issue and skip the test or put a temporary upper bound in place. We're using pretty core/standard functions here, so I am not too worried about seeing too many regressions once things work. That would be really surprising. And in terms of debugging or even contributing upstream, I'd much rather work with CuPy or PyTorch than with things like pytest/sphinx/mpmath.

Also, the CuPy and PyTorch teams (and Dask and JAX) have invested large amounts of effort in NumPy and SciPy compatibility, so I'm pretty sure they'd appreciate and are willing to address bug reports.

Comment on lines +1066 to +1071
xp_assert_close(u.T @ u, xp.eye(3), atol=1e-6)
xp_assert_close(vh.T @ vh, xp.eye(3), atol=1e-6)
sigma = xp.zeros((u.shape[0], vh.shape[0]), dtype=s.dtype)
for i in range(s.shape[0]):
    sigma[i, i] = s[i]
assert_array_almost_equal(u @ sigma @ vh, a)
xp_assert_close(u @ sigma @ vh, a, rtol=1e-6)
@lucascolley (Member Author):

The tolerances around here could do with a look; I'm not sure why I wrote a mixture of atol and rtol.

@lucascolley (Member Author):

Ah, I remember: it's because the default for np.testing.assert_array_almost_equal is roughly equivalent to rtol=0, atol=1.5e-6.

@lucascolley (Member Author) commented Apr 1, 2024

CI should be green. I think this is almost ready. A few questions remain:

  • this PR involves a lot of general improvements to the tests (checking more dtypes, checking shapes, stricter tolerances), but these improvements clearly haven't been pushed all the way. I think doing so would be a huge effort given the size of this diff, but I'm happy to work a little more if there are particular areas that could do with some TLC.
  • I don't know if we want some sort of policy about what to do with tolerances. I've basically just gone with the defaults of the assertions, but where identity matrices are used we get atol failures due to values being non-zero, so I have introduced atol bumps where needed (the default atol is 0).
  • A lot of the tests which are currently skipped could be split into parts which are compatible and parts which aren't (at least for this PR). I think it would be too much effort to split all of them, but some may be worth it. I've marked a few with TODOs.

EDIT: spoke too soon on CI, but it looks like just an atol=0 thing so far

EDIT 2: finally green :)

@lucascolley requested a review from ilayn April 1, 2024 23:09
@lucascolley marked this pull request as ready for review April 1, 2024 23:11
@mdhaber (Contributor) left a comment

Just thought I'd come try to support this in case it would help @lucascolley. I'm also happy to stay out of it because it is easy for the kitchen to get too crowded.

If you'd like me to continue, one other question beforehand - were you interested in seeing if a mechanism like special's support_alternative_backends would work instead of modifying the individual functions? The code is mostly boilerplate, so I think it can be abstracted out. I'll be looking at whether the special approach can be used for signal later this week, and I could also throw this into the mix. The thought is that it could eliminate most of the diff in the functions themselves.
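For concreteness, here is a rough sketch of what such a dictionary-driven mechanism might look like for linalg. The names _FUNC_ARG_COUNTS and _dispatch_linalg are hypothetical, and the real special machinery differs in its details; this is only an outline of the idea, not an implementation from either PR:

import functools
import numpy as np
from scipy._lib._array_api import array_namespace, is_numpy

# Hypothetical table: function name -> number of leading array arguments.
_FUNC_ARG_COUNTS = {'solve': 2, 'svd': 1, 'det': 1}

def _dispatch_linalg(name, scipy_func):
    # Wrap a scipy.linalg function so that non-NumPy inputs either use the
    # backend's own linalg extension or round-trip through NumPy.
    n_arrays = _FUNC_ARG_COUNTS[name]

    @functools.wraps(scipy_func)
    def wrapper(*args, **kwargs):
        arrays = args[:n_arrays]
        xp = array_namespace(*arrays)
        if is_numpy(xp):
            return scipy_func(*args, **kwargs)
        if hasattr(xp, 'linalg'):
            # Standard-only call; SciPy-specific kwargs would have to be
            # validated or rejected before reaching this point.
            return getattr(xp.linalg, name)(*arrays)
        # Fall back: convert to NumPy, compute with SciPy, convert back.
        np_arrays = tuple(np.asarray(arr) for arr in arrays)
        result = scipy_func(*np_arrays, *args[n_arrays:], **kwargs)
        return xp.asarray(result)

    return wrapper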

unsupported_args = {
    'lower': lower,
    'assume_a': assume_a != 'gen',
    'transposed': transposed
@mdhaber (Contributor), Jun 15, 2024:

Are we silently ignoring overwrite_a and overwrite_b since they don't guarantee anything about the behavior? Just checking understanding.

@lucascolley (Member Author):

Yep.

@@ -141,11 +145,37 @@ def solve(a, b, lower=False, overwrite_a=False,
    array([ True, True, True], dtype=bool)

    """
    xp = array_namespace(a, b)
    if check_finite:
@mdhaber (Contributor), Jun 15, 2024:

I suppose the idea here is that it's easy enough for us to support check_finite even though it's not in the standard?

Not that we need to, but it seems like it would be just as easy to support transposed. With slightly more effort, we could support lower with other options for assume_a. (It would not necessarily lead to a more efficient solution; just a more consistent interface.)

Just want to make sure I'm understanding correctly - not suggesting any action.
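Purely to illustrate the "slightly more effort" point, one hypothetical way transposed could be handled in an array-agnostic fashion (real-valued inputs only, matching SciPy's current restriction; this is not code from the PR):

if transposed:
    # Solve a^T x = b by transposing the coefficient matrix up front, then
    # proceed with the usual dispatch to xp.linalg.solve.
    a = xp.matrix_transpose(a)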

@lucascolley (Member Author):

Yes, and yes. I was just implementing what was 'free' in this PR (just using a kwarg of _asarray here) and leaving the rest in xp_unsupported_args. Ideally, that list would be reduced as much as possible, which can happen in two ways:

  • us implementing agnostic versions of the args (which may be easy for some of them)
  • args being added to the standard API

}
if any(unsupported_args.values()):
    xp_unsupported_args(unsupported_args)
if hasattr(xp, 'linalg'):
Contributor:

Is it possible for xp to have linalg but linalg not to have solve?

@lucascolley (Member Author):

Yes, but we don't want to support such a library. Ref: https://data-apis.org/array-api/latest/extensions/index.html

> Extension module implementors must aim to provide all functions and other public objects in an extension. The rationale for this is that downstream usage can then check whether or not the extension is present (using hasattr(xp, 'extension_name') should be enough), and can then assume that functions are implemented. This in turn makes it also easy for array-consuming libraries to document which array libraries they support - e.g., “all libraries implementing the array API standard and its linear algebra extension”.

Comment on lines +163 to +165
dtype = xp.result_type(a, b, xp.float32)
a = xp.astype(a, dtype)
b = xp.astype(b, dtype)
Contributor:

You might want to do this dtype work even if you have to fall back to NumPy. The reason is that if a package doesn't have linalg.solve and then adds it, the behavior could (conceivably) change if dtype conversions start happening before the arrays are passed into that function.

@lucascolley (Member Author):

That seems reasonable; I can't remember whether there was any reason for including this only behind this guard.
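For reference, the suggested reordering could look roughly like this inside solve; this is a sketch of the idea rather than this PR's code:

xp = array_namespace(a, b)
# Promote to a common floating-point dtype before branching on the backend,
# so the result dtype cannot change if a library later gains a linalg extension.
dtype = xp.result_type(a, b, xp.float32)
a = xp.astype(a, dtype)
b = xp.astype(b, dtype)
if hasattr(xp, 'linalg'):
    return xp.linalg.solve(a, b)
# ...otherwise continue with the existing NumPy/LAPACK code path.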

@lucascolley (Member Author) commented

> were you interested in seeing if a mechanism like special's support_alternative_backends would work instead of modifying the individual functions? The code is mostly boilerplate, so I think it can be abstracted out. I'll be looking at whether the special approach can be used for signal later this week, and I could also throw this into the mix. The thought is that it could eliminate most of the diff in the functions themselves.

sure, feel free to have a go and we can compare notes. That should be a useful exercise.

Note that this is a rather unusual PR, deliberately including only boilerplate such that we can make use of the extension module. Most of these np-backend functions still have a lot of code that can be written in an agnostic way. I still want to convert as much of linalg as possible to be array-agnostic, and only use np where necessary, but that will come with future PRs.

@mdhaber (Contributor) commented Jun 15, 2024

> Note that this is a rather unusual PR, deliberately including only boilerplate...

Do you mean only some parts of linalg (i.e. the functions you've treated here) will have a direct correspondence with functions in other libraries, but others can be manually written in terms of other xp calls?

If so, special's dispatch mechanism supports that sort of thing, too, right? For instance, there is a backend-agnostic implementation of rel_entr if the backend does not have a function with that name.

@lucascolley (Member Author) commented Jun 15, 2024

I'm more so talking about functions which can be partially converted.

I suppose we need to figure out the long-term worth of such partial conversions. If we can make a function agnostic, apart from one LAPACK call, should we? Or should we just resign ourselves to converting to NumPy for the entire function? Sure, converting to NumPy immediately is fine for now, but will we want it to stay that way as this stuff leaves the experimental stage?

I'm not sure on the answers to any of these questions, but the guidance from Pamphile's original PR was pretty clear that we only go to np when necessary. If we don't want to do that, we should be clear about whether that is temporary, or whether it isn't, in which case we should probably change the guidance.

The answer may well vary by submodule - if some code is destined to always need compiled code for example. There was some discussion along these lines above somewhere.


EDIT: apologies for the double-semantics of "conversion"

@mdhaber (Contributor) commented Jun 15, 2024

> If we can make a function agnostic, apart from one LAPACK call, should we? Or should we just resign ourselves to converting to NumPy for the entire function?

Personally, I think it's a matter of balance between performance and simplicity. For stats, we partially converted functions to compute the statistic with the array backend, then we computed the p-value with NumPy. This made sense because the statistic was a reducing operation, so there was (potentially) a lot to be gained even with partial conversion. (Of course, we took a second pass and converted the rest for completeness.)

In linalg, I'm imagining that the one LAPACK call is going to be the expensive part, in which case I'd think that partial conversion would not be worth the effort, especially when wholesale conversion could be as simple as adding the function's name and number of arguments to a dictionary (e.g. when a special function's name and number of array arguments are added to the dictionary in _support_alternative_backends, it uses the function from the backend if it's available and converts to and from NumPy otherwise).

> If we don't want to do that, we should be clear about whether that is temporary, or whether it isn't, in which case we should probably change the guidance.

Yeah, I'd second that. I hope we can also clarify the guidance from gh-18286 (#18286 (comment)):

> For as-yet-unsupported GPU execution when hitting compiled code, we will raise exceptions. The alternative considered was to transfer to CPU, execute, and transfer back (e.g., for PyTorch). A pro of doing that would be that everything works, and there may still be performance gains. A con is that it silently does device transfers, usually not a good idea.

This suggests that device transfers are inherently bad, and that inherent badness would even outweigh everything working plus performance gain. So avoiding device transfers was a good part of my motivation for gh-20549, but that was not met enthusiastically. So under what conditions are device transfers OK, or what are the considerations besides performance?

> EDIT: apologies for the double-semantics of "conversion"

Yes, I've thought maybe I should refer to these array-API rewrites as "translations", or to changing to NumPy arrays as "coercions". But I have not really decided for myself, and I am happy with the double meaning of "conversion" in the meantime.

@lucascolley (Member Author) commented

Thanks Matt, agreed with your most recent comment. The goal is to abstract as much of the boilerplate away as possible (e.g.

def _execute_1D(func_str, pocketfft_func, x, n, axis, norm, overwrite_x, workers, plan):
    xp = array_namespace(x)
    if is_numpy(xp):
        return pocketfft_func(x, n=n, axis=axis, norm=norm,
                              overwrite_x=overwrite_x, workers=workers, plan=plan)
    norm = _validate_fft_args(workers, plan, norm)
    if hasattr(xp, 'fft'):
        xp_func = getattr(xp.fft, func_str)
        return xp_func(x, n=n, axis=axis, norm=norm)
    x = np.asarray(x)
    y = pocketfft_func(x, n=n, axis=axis, norm=norm)
    return xp.asarray(y)

def _execute_nD(func_str, pocketfft_func, x, s, axes, norm, overwrite_x, workers, plan):
    xp = array_namespace(x)
    if is_numpy(xp):
        return pocketfft_func(x, s=s, axes=axes, norm=norm,
                              overwrite_x=overwrite_x, workers=workers, plan=plan)
    norm = _validate_fft_args(workers, plan, norm)
    if hasattr(xp, 'fft'):
        xp_func = getattr(xp.fft, func_str)
        return xp_func(x, s=s, axes=axes, norm=norm)
    x = np.asarray(x)
    y = pocketfft_func(x, s=s, axes=axes, norm=norm)
    return xp.asarray(y)
), so I agree that if we can achieve the same via a _support_alternative_backends.py-esque method then that sounds better. I'll put this PR on hold while you work on that, if you would still like to.
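As an aside, those two helpers could conceivably collapse into a single one, along the lines of the sketch below; this is hypothetical and assumes the n/axis vs. s/axes distinction is carried in a kwargs dict:

def _execute(func_str, pocketfft_func, x, *, norm, overwrite_x, workers, plan, **shape_kwargs):
    # shape_kwargs is either {'n': n, 'axis': axis} or {'s': s, 'axes': axes}.
    xp = array_namespace(x)
    if is_numpy(xp):
        return pocketfft_func(x, norm=norm, overwrite_x=overwrite_x,
                              workers=workers, plan=plan, **shape_kwargs)
    norm = _validate_fft_args(workers, plan, norm)
    if hasattr(xp, 'fft'):
        return getattr(xp.fft, func_str)(x, norm=norm, **shape_kwargs)
    # Fall back to pocketfft on a NumPy copy and convert the result back.
    y = pocketfft_func(np.asarray(x), norm=norm, **shape_kwargs)
    return xp.asarray(y)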

@lucascolley (Member Author) commented

Closing; this is still discoverable in the linked issue, but it is best taken forward in smaller chunks.

@lucascolley closed this Sep 6, 2025