ENH: stats: add array API support for directional_stats
#20794
Looks like
I thought Have to take a deeper look into this
EDIT: scipy/scipy/_lib/_array_api.py Lines 490 to 501 in b863eb9
For

    # maybe replace by `scipy.linalg` if/when converted
    def xp_vector_norm(x, *, ..., xp=None):
        xp = array_namespace(x) if xp is None else xp
        # check for optional `linalg` extension
        if hasattr(xp, 'linalg'):
            return xp.linalg.vector_norm(...)
        else:
            # return (x @ x)**0.5
            # or to get the right behavior with nd, complex arrays
            return xp.sum(xp.conj(x) * x, axis=axis, keepdims=keepdims)**0.5
@lucascolley I'm not sure I understand the need to check
Suppose

    xp = array_namespace(x)
    x = xp.asarray(x)  # this is still common because we still accept lists
    xp.linalg.vector_norm(x)

Why is it different if
In any case, I'd suggest implementing the vector norm in terms of standard operations instead of falling back to NumPy:

    if hasattr(xp, 'linalg'):
        return xp.linalg.vector_norm(x, axis=axis, keepdims=keepdims)
    else:
        # return (x @ x)**0.5
        # or to get the right behavior with nd, complex arrays
        return xp.sum(xp.conj(x) * x, axis=axis, keepdims=keepdims)**0.5

Even NumPy isn't careful to avoid premature over/underflow, so I don't think we need to be more careful. And if others need different norms, they can add the
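As a quick check of the fallback's claim, the sketch below compares the `xp.sum(xp.conj(x) * x, ...)**0.5` expression with `np.linalg.norm` on n-d real and complex input, using NumPy as the namespace. `fallback_vector_norm` is a hypothetical name. It also surfaces one caveat: for complex `x`, `conj(x) * x` has a zero imaginary part but still a complex dtype, so the sketch adds an `xp.real` call before the square root (an assumption, not part of the suggestion above).

```python
import numpy as np


def fallback_vector_norm(x, *, axis=None, keepdims=False, xp=np):
    """Euclidean norm built only from core array API operations.

    Hypothetical helper mirroring the `xp.sum(xp.conj(x) * x, ...)**0.5`
    fallback. The `xp.real` call is an addition so that complex input
    yields a real-valued result, as `np.linalg.norm` does.
    """
    squared = xp.sum(xp.conj(x) * x, axis=axis, keepdims=keepdims)
    return xp.real(squared) ** 0.5


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 4))
z = x + 1j * rng.standard_normal((2, 3, 4))

# Real n-d input, reducing over the last axis.
print(np.allclose(fallback_vector_norm(x, axis=-1), np.linalg.norm(x, axis=-1)))
# Complex n-d input, reducing over a middle axis and keeping dimensions.
res = fallback_vector_norm(z, axis=1, keepdims=True)
print(res.shape, res.dtype)
```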
@mdhaber you're right, I've updated my suggestion.
Two small comments:
1/ In the test case
2/ Regarding the lack of
7b2c1b8 to ad9ab68
[skip cirrus] [skip circle] Co-authored-by: Matt Haberland <mhaberla@calpoly.edu>
[skip cirrus]
[skip cirrus]
[skip cirrus] [skip circle]
Some of the tests had
scipy/stats/tests/test_morestats.py

    full_array = xp.asarray(np.tile(data, (2, 2, 2, 1)))
    expected = xp.asarray([[[1., 0., 0.],
                            [1., 0., 0.]],
                           [[1., 0., 0.],
                            [1., 0., 0.]]],
                          dtype=xp.float64)
Should we perform this test with the default floating type, i.e.

-    full_array = xp.asarray(np.tile(data, (2, 2, 2, 1)))
-    expected = xp.asarray([[[1., 0., 0.],
-                            [1., 0., 0.]],
-                           [[1., 0., 0.],
-                            [1., 0., 0.]]],
-                          dtype=xp.float64)
+    full_array = xp.asarray(np.tile(data, (2, 2, 2, 1)).tolist())
+    expected = xp.asarray([[[1., 0., 0.],
+                            [1., 0., 0.]],
+                           [[1., 0., 0.],
+                            [1., 0., 0.]]])

?
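The `.tolist()` in the suggestion matters because `xp.asarray` then receives Python floats and falls back to the namespace's default floating dtype instead of inheriting the NumPy array's dtype. A NumPy-only sketch of that behavior, with a made-up `float32` `data` array (NumPy's default is `float64`; PyTorch's default is typically `float32`, which is not exercised here):

```python
import numpy as np

xp = np  # stand-in namespace; torch.asarray on a list would use torch's default dtype

# Made-up input, deliberately float32 so the two paths differ.
data = np.asarray([[1., 0., 0.], [2., 0., 0.]], dtype=np.float32)

# Passing the NumPy array keeps its dtype (float32).
kept = xp.asarray(np.tile(data, (2, 2, 2, 1)))
# Passing nested Python floats lets asarray pick the default dtype.
default = xp.asarray(np.tile(data, (2, 2, 2, 1)).tolist())

print(kept.dtype, default.dtype, kept.shape)
```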
I'm happy with either, personally. I think there is value in doing some or most tests with default floating point type, but there is also some value in testing with non-default types. If that happens naturally by convenience in some tests, great.
I'm only forcing the dtype of `expected`. For default torch input (float32), `directional_stats` returns `float64`. Is that okay?
Same thing for the next test function.
Is that because of data-apis/array-api-compat#152? If not, what is converting it? `vector_norm`?
"For default torch input (float32), `directional_stats` returns `float64`. Is that okay?"

I don't think so. I think it's just NumPy that converts to `float64`, and that's because of data-apis/array-api-compat#152 (and `sum` is used). For `torch`, the dtype is preserved as it should be, yet even if it is `float32`, `repr(f"{res.mean_resultant_length}")` prints lots of digits!
from array_api_compat import torch
from array_api_compat import numpy as np
from scipy import stats
rng = np.random.default_rng(2549824598234528)
x = np.astype(rng.random((3, 10)), np.float32)
y = torch.asarray(x)
assert y.dtype == torch.float32
res = stats.directional_stats(y)
assert res.mean_direction.dtype == torch.float32
assert res.mean_resultant_length.dtype == torch.float32
print(f"{res.mean_resultant_length}")
# 0.941432774066925
I guess that makes sense if Python converts it to a `float` before printing. I don't think I've noticed that before, though; it's uncommon for the result class to have `__repr__` defined explicitly.
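A NumPy analogue of the "lots of digits" effect, assuming (not verified here) that formatting the 0-d torch tensor goes through a Python `float`: a `float32` value printed directly uses the shortest repr that round-trips in `float32`, whereas the same value converted to a Python `float` (a double) prints the digits needed to round-trip in double precision.

```python
import numpy as np

x32 = np.float32(0.1)

# Shortest representation that round-trips as a float32.
as_float32 = str(x32)
# The same stored value as a Python float (float64) needs more digits.
as_python_float = repr(float(x32))

print(as_float32)
print(as_python_float)
```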
I said before that it's OK to let `sum` do its thing (return `np.float64` with `np.float32` input), but the problem is that this will change when data-apis/array-api-compat#152 is resolved. It might be better to explicitly provide the `dtype` to `sum` so it doesn't change later.
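A minimal sketch of pinning the result type by passing `dtype` to `sum`, using NumPy as the namespace. `sum_keep_dtype` is a hypothetical helper name; the `dtype` keyword itself is part of both `numpy.sum` and the array API standard's `sum`.

```python
import numpy as np


def sum_keep_dtype(x, *, axis=None, keepdims=False, xp=np):
    # Request the input's own dtype, so a namespace that would otherwise
    # upcast float32 sums to float64 cannot change the result type.
    return xp.sum(x, axis=axis, keepdims=keepdims, dtype=x.dtype)


x = np.ones((3, 10), dtype=np.float32)
s = sum_keep_dtype(x, axis=-1)
print(s.dtype, s)
```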
What should we change, then? I assume you are referring to `xp.sum` being called within our call to `xp.mean`? We don't call `xp.sum` directly for PyTorch or NumPy, as that path uses `xp.linalg`.
Well, then, not even NumPy converts to `float64`. You wrote that "For default torch input (float32), directional_stats returns float64. Is that okay?" and I was trying to explain why that might be the case, but it doesn't look like that happens. I guess I didn't actually test NumPy before, but it doesn't seem to happen for NumPy either because, as you say, NumPy skips `sum`, so data-apis/array-api-compat#152 doesn't come up.
I think you only have to force the dtype of `expected` because `data` starts out as a `float64` array, so Torch's input is a `float64` tensor. If we use `xp.asarray` instead of `np.array`, it should work without specifying the dtype.
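A sketch of a test body that avoids hard-coding `dtype=xp.float64` by tying `expected` to the input's dtype, run for both `float32` and `float64`. NumPy stands in for `xp`, the `data` values are made up, and the inline mean-of-unit-vectors computation is only a stand-in for the function under test.

```python
import numpy as np

xp = np  # stand-in for the namespace under test

# Made-up input: every row points along +x, so the mean direction is [1, 0, 0].
data = np.asarray([[1., 0., 0.], [3., 0., 0.]])

for dtype in (xp.float32, xp.float64):
    full_array = xp.asarray(np.tile(data, (2, 2, 2, 1)), dtype=dtype)
    # Derive the dtype of `expected` from the input instead of fixing float64.
    expected = xp.asarray([[[1., 0., 0.], [1., 0., 0.]],
                           [[1., 0., 0.], [1., 0., 0.]]], dtype=full_array.dtype)
    # Stand-in computation: normalize rows, average along axis 2, renormalize.
    unit = full_array / xp.sum(full_array * full_array, axis=-1, keepdims=True) ** 0.5
    mean = xp.mean(unit, axis=2)
    direction = mean / xp.sum(mean * mean, axis=-1, keepdims=True) ** 0.5
    assert direction.dtype == expected.dtype
    assert np.allclose(direction, expected)
```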
Not too much to say here. Thanks!
[skip cirrus] [skip circle]
please let me know how we can fix the regex failure in CI
LGTM after these changes!
[skip cirrus] [skip circle]
scipy/_lib/_array_api.py

    if SCIPY_ARRAY_API and hasattr(xp, 'linalg'):
        return xp.linalg.vector_norm(x, axis=axis, keepdims=keepdims, ord=ord)
    else:
        if ord != 2:
            raise ValueError(
                "only the Euclidean norm (`ord=2`) is currently supported in "
                "`xp_vector_norm` for backends not implementing the `linalg` extension."
            )
        # return (x @ x)**0.5
        # or to get the right behavior with nd, complex arrays
        return xp.sum(xp.conj(x) * x, axis=axis, keepdims=keepdims)**0.5
I don't fully understand this. If `SCIPY_ARRAY_API=0`, then `xp.sum(xp.conj(x) * x, axis=axis, keepdims=keepdims)**0.5` is always used, which doesn't seem preferable to `np.linalg.norm`?
Right, I think I should rework the logic to:

    if SCIPY_ARRAY_API:
        if hasattr(xp, 'linalg'):
            ...
        else:
            ...
    else:
        # use np.linalg.norm
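A runnable sketch of the reworked dispatch, under stated assumptions: `array_api_enabled` is a hypothetical keyword standing in for SciPy's `SCIPY_ARRAY_API` flag, only the Euclidean norm is handled (so there is no `ord` argument or `ValueError` branch), and the core-operations fallback adds an `xp.real` call so complex input gives a real result. The NumPy branch relies on `np.linalg.norm` with its default `ord=None`, which gives the Euclidean norm over an integer axis and the 2-norm of the flattened array when `axis=None`; passing `ord=2` with `axis=None` would instead give the spectral norm for 2-D input.

```python
import types

import numpy as np


def xp_vector_norm_sketch(x, *, axis=None, keepdims=False, xp=None,
                          array_api_enabled=False):
    """Euclidean vector norm following the dispatch outlined above.

    Hypothetical sketch: `array_api_enabled` stands in for SCIPY_ARRAY_API.
    """
    xp = np if xp is None else xp
    if array_api_enabled:
        if hasattr(xp, 'linalg'):
            # Optional `linalg` extension is available: use it directly.
            return xp.linalg.vector_norm(x, axis=axis, keepdims=keepdims)
        # No `linalg` extension: build the norm from core operations.
        squared = xp.sum(xp.conj(x) * x, axis=axis, keepdims=keepdims)
        return xp.real(squared) ** 0.5
    # Array API support disabled: inputs are NumPy arrays, so defer to NumPy.
    return np.linalg.norm(x, axis=axis, keepdims=keepdims)


x = np.arange(12.0).reshape(3, 4)
# A namespace exposing only core functions and no `linalg` attribute.
core_only = types.SimpleNamespace(sum=np.sum, conj=np.conj, real=np.real)

a = xp_vector_norm_sketch(x, axis=1)  # NumPy path
b = xp_vector_norm_sketch(x, axis=1, xp=core_only, array_api_enabled=True)  # core-ops path
print(np.allclose(a, b))
```

Passing a namespace without a `linalg` attribute, as in the usage lines, is what exercises the core-operations path.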
this should be fixed now
[skip cirrus] [skip circle]
please squash merge!
Thanks everyone!
Thanks everyone for pushing this in my absence. :)
Co-authored-by: Lucas Colley <lucas.colley8@gmail.com> Co-authored-by: Matt Haberland <mhaberla@calpoly.edu> Co-authored-by: Jake Bowhay <60778417+j-bowhay@users.noreply.github.com>
Reference issue
Towards #20544
What does this implement/fix?
Adds array API support for `directional_stats`.
Additional information
This is the first time I've looked into array API stuff, so any pointers are highly appreciated. I mostly checked what is available in the standard and adapted the existing tests based on the tests of other array API compatible functions. I do not have PyTorch installed locally, so I am mostly relying on CI for now.
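For context, the statistics involved need only operations from the array API standard: normalize each sample to unit length, average the unit vectors, and split the mean into a length and a direction. The sketch below is a simplified, hypothetical version (`directional_stats_sketch` returns a plain tuple and uses NumPy as the default namespace); it is not SciPy's implementation, and it does not guard against a zero mean vector.

```python
import numpy as np


def directional_stats_sketch(samples, *, axis=0, normalize=True, xp=np):
    samples = xp.asarray(samples)
    if samples.ndim < 2:
        raise ValueError("samples must be at least two-dimensional")
    # Put observations first; vector components stay on the last axis.
    samples = xp.moveaxis(samples, axis, 0)
    if normalize:
        # Scale each observation to unit Euclidean length.
        norms = xp.sum(samples * samples, axis=-1, keepdims=True) ** 0.5
        samples = samples / norms
    mean = xp.mean(samples, axis=0)
    # Length of the mean vector (the mean resultant length).
    length = xp.sum(mean * mean, axis=-1, keepdims=True) ** 0.5
    mean_direction = mean / length
    mean_resultant_length = xp.squeeze(length, axis=-1)
    return mean_direction, mean_resultant_length


direction, length = directional_stats_sketch(np.array([[1.0, 0.0], [0.0, 1.0]]))
print(direction, length)
```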