ENH: stats: add array API support to some of _axis_nan_policy decorator
#22857
base: main
@@ -268,12 +271,12 @@ def _check_empty_inputs(samples, axis):
     return output


-def _add_reduced_axes(res, reduced_axes, keepdims):
+def _add_reduced_axes(res, reduced_axes, keepdims, xp=np):
The changes in this file are pretty straightforward xp-translations that allow non-numpy arrays to go through the first part of the decorator, which makes a few things much easier for wrapped functions.
- The dimensions specified by `axis=None` and `axis` tuples are raveled (so wrapped functions only have to consider a scalar, integer `axis`).
- The remaining `axis` is moved to `-1`, which tends to make indexing easier.
- `keepdims` is supported automatically by adding back any axes that get reduced away.

It also ensures that the right thing happens when the input array is too small (e.g. empty): emit `SmallSampleWarning` and return NaN.
scipy/stats/_axis_nan_policy.py (outdated)
if not is_numpy(xp):
    res = hypotest_fun_out(*samples, axis=axis, **kwds)
    res = result_to_tuple(res, n_out)
    res = _add_reduced_axes(res, reduced_axes, keepdims, xp=xp)
    return tuple_to_result(*res)
These are all pretty standard invocations; see e.g. lines 599-602.
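To illustrate what those four lines do, here is a toy, NumPy-only sketch. It assumes a two-output result type (the hypothetical `MinMaxResult`) and stand-in helpers named after the ones in the snippet; their real signatures in `_axis_nan_policy.py` may differ. `add_reduced_axes` inserts one axis at a time, so it only needs `expand_dims` to accept a single integer `axis`.

```python
from collections import namedtuple

import numpy as np

# Hypothetical two-output result type, standing in for e.g. ModeResult.
MinMaxResult = namedtuple("MinMaxResult", ["min", "max"])


def minmax(x, axis=-1):
    """Toy hypothesis-test-like function returning two outputs."""
    return MinMaxResult(np.min(x, axis=axis), np.max(x, axis=axis))


def result_to_tuple(res, n_out):
    # Stand-in helper: a namedtuple already is a tuple; wrap bare arrays.
    return tuple(res) if n_out > 1 else (res,)


def tuple_to_result(*res):
    # Stand-in helper: rebuild the result object from plain outputs.
    return MinMaxResult(*res)


def add_reduced_axes(res, reduced_axes, keepdims, xp=np):
    """Re-insert each reduced axis (original positions) as length 1."""
    if not keepdims:
        return res
    out = []
    for r in res:
        # Ascending order keeps earlier insertions from shifting later ones.
        for ax in sorted(reduced_axes):
            r = xp.expand_dims(r, axis=ax)
        out.append(r)
    return tuple(out)


# Data as the decorator would pass it: original axis 1 of a (2, 3, 4)
# array has already been moved to the end, giving shape (2, 4, 3).
x = np.arange(24.0).reshape(2, 3, 4)
samples = (np.moveaxis(x, 1, -1),)
res = minmax(*samples, axis=-1)
res = result_to_tuple(res, n_out=2)
res = add_reduced_axes(res, reduced_axes=(1,), keepdims=True)
result = tuple_to_result(*res)
print(result.min.shape, result.max.shape)  # (2, 1, 4) (2, 1, 4)
```

Converting to a plain tuple first means `add_reduced_axes` can treat every output uniformly, regardless of the wrapped function's result class.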
 # When `order` is array-like with size > 1, moment produces an *array*
 # rather than a tuple, but the zeroth dimension is to be treated like
 # separate outputs. It is important to make the distinction between
 # separate outputs when adding the reduced axes back (`keepdims=True`).
 def _moment_tuple(x, n_out):
-    return tuple(x) if n_out > 1 else (x,)
+    return tuple(x[i, ...] for i in range(x.shape[0])) if n_out > 1 else (x,)
See the code comment. For better or for worse, this is supposed to iterate over the zeroth axis, and we have to do that manually for `array_api_strict`.
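The difference between the two versions can be reproduced without `array_api_strict`. The sketch below assumes a hypothetical `NoIterArray` class that, like an array type whose iteration is disallowed, supports `.shape` and indexing but raises on `__iter__`. `tuple(x)` fails, while explicit `x[i, ...]` indexing still splits the zeroth axis into separate outputs.

```python
import numpy as np


class NoIterArray:
    """Hypothetical array stand-in: indexing and `.shape` work, but
    Python iteration is refused."""

    def __init__(self, data):
        self._data = np.asarray(data)
        self.shape = self._data.shape

    def __iter__(self):
        raise TypeError("iteration over arrays is not supported")

    def __getitem__(self, key):
        return self._data[key]


def moment_tuple_old(x, n_out):
    # Relies on iterating over the zeroth axis.
    return tuple(x) if n_out > 1 else (x,)


def moment_tuple_new(x, n_out):
    # Index the zeroth axis explicitly instead of iterating.
    return tuple(x[i, ...] for i in range(x.shape[0])) if n_out > 1 else (x,)


x = NoIterArray([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])  # 3 outputs
try:
    moment_tuple_old(x, n_out=3)
except TypeError as exc:
    print("old:", exc)
print("new:", [r.tolist() for r in moment_tuple_new(x, n_out=3)])
```

Defining `__iter__` to raise is needed here because, without it, Python would fall back to the `__getitem__` sequence protocol and `tuple(x)` would silently work.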
That might be helpful. Now that I'm in "review lots of array API PRs" mode, I'm inclined to keep going, but it's easier to find time to review and merge in one go if the diffs are smaller and more self-contained. This now has merge conflicts; splitting it in two when updating may be nice.
Will do.
@@ -486,11 +486,12 @@ def _mode_result(mode, count):
     # When a slice is empty, `_axis_nan_policy` automatically produces
     # NaN for `mode` and `count`. This is a reasonable convention for `mode`,
     # but `count` should not be NaN; it should be zero.
-    i = np.isnan(count)
+    xp = array_namespace(mode, count)
These could be reverted.
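Only the first changed line of `_mode_result` is visible in the hunk above, so the rest of the translated function is not shown here. As a minimal sketch, assuming NumPy as a stand-in for `xp`, the NaN-to-zero step could be written with `xp.where` instead of boolean-mask assignment, since some array libraries (e.g. JAX) have immutable arrays. `fix_empty_counts` is a hypothetical name, not the PR's code.

```python
import numpy as np


def fix_empty_counts(mode, count, xp=np):
    """Hypothetical xp-generic version of the NaN-to-zero step.

    Empty slices come back with NaN for both `mode` and `count`; only
    `count` is replaced, by 0. `xp.where` builds a new array rather than
    assigning in place through a boolean mask.
    """
    count = xp.asarray(count)
    count = xp.where(xp.isnan(count), xp.zeros_like(count), count)
    return mode, count


mode = np.asarray([2.0, np.nan])
count = np.asarray([3.0, np.nan])  # the second slice was empty
print(fix_empty_counts(mode, count))
```

Using `xp.zeros_like(count)` rather than a Python scalar keeps both branches of `where` as arrays of the same dtype.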
Reference issue
Toward gh-20544
What does this implement/fix?
Adds array API support to `scipy.stats.mode`.

Update: yeah... I got a bit carried away, and now this also overhauls `_axis_nan_policy` to work with non-NumPy arrays, since `mode` needed it for many of the tests. LMK if this needs to get broken up.

Additional information
Dask complains about not being able to compute the chunk size. Haven't dealt with this before, so wisdom appreciated.
CuPy runs into data-apis/array-api-compat#312.
The shortcut when `a.ndim == 1` used `np.unique`, which treats NaNs as the same, but the standard says that NaNs are all different. There are a few options:
- use `unique_counts` despite all NaNs being treated as distinct
- use `xp.unique` for all backends for which it's defined; otherwise, fall back to the generic n-d array calculation

Thoughts?