ENH: array types: add dask.array support #20956
bdd19fc to 436cba5
@mdhaber FYI there are a few places in (I'll fix it here, but just in case you were going on to write something similar before this is merged)
@ev-br could you suggest how to fix scipy/scipy/ndimage/_support_alternative_backends.py (lines 54 to 58 in 87ed0ca) for ?
What is result and what is type(result)?
Then the fix is to wrap the
cc @phofl (for awareness)
@mdhaber does the diff in e63b914 look okay to you? I've been experimenting with Dask and a SciPy nightly, and this diff allows e.g. If so, I think I'd like to merge that separately before this PR.
Just to follow up here @mdhaber - that change isn't specific to Dask.
20dbdf8 to 2ea1c63
@ev-br do you have any idea why we are hitting "dask.array has no
Looks like indeed some np.asarray / np.asanyarray calls are missing and internal functions receive mixtures of numpy and dask arrays.
Makes sense. I wonder why this doesn't trip up PyTorch CPU but does for Dask.
Planning on helping give this a little nudge. Using the newly wrapped fft module in array-api-compat, we can get down to 11 failures on fft at least, mainly regarding dtype differences. I haven't looked at any of the other failures yet. EDIT 1: Have gotten special to work (other than the issue with suppressing the runtime warnings from dask/numpy).
Feel free to make a PR to my branch, @lithomas1!
I made a draft PR here. Could you pull main on your branch later if you have more time?
Thanks @lithomas1, that definitely helped to bring the number of errors down. Of the remaining errors, some look trivial, others I'm not so sure.
# TODO: we shouldn't be mutating in-place here unless we make a copy
# dask arrays do not copy before this somehow
#a[i_max] = -xp.inf
a = xp.where(i_max, -xp.asarray(xp.inf, dtype=a.dtype), a)
just flagging this as a review comment so it doesn't get lost
A copy is always made, due to this line near the top of the logsumexp implementation:
a, b = xp_broadcast_promote(a, b, ensure_writeable=True, ...
So this will yield correct semantics. It may be slower (didn't check) since a copy is always made now, while the previous implementation (e.g., in v1.14.1) only made a copy if np.any(b == 0) is True.
Potentially, not copying and avoiding any in-place mutation but using xp.where is therefore beneficial also for performance with numpy in the common case where b isn't used or doesn't have elements that are 0.
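For readers following along, here is a minimal sketch of the two approaches being compared, using plain NumPy in place of xp. The helper names mask_max_inplace and mask_max_where are hypothetical and are not taken from the logsumexp implementation; only the xp.where expression mirrors the diff above.

```python
import numpy as np


def mask_max_inplace(a, i_max):
    # Hypothetical helper: mutates its argument, so the caller's array
    # changes unless a copy was made beforehand.
    a[i_max] = -np.inf
    return a


def mask_max_where(a, i_max):
    # Hypothetical helper: builds a new array and leaves the input untouched,
    # so no defensive copy is needed.
    return np.where(i_max, -np.asarray(np.inf, dtype=a.dtype), a)


a = np.array([1.0, 3.0, 2.0])
i_max = a == a.max()

out = mask_max_where(a, i_max)           # `a` is unchanged afterwards
b = mask_max_inplace(a.copy(), i_max)    # explicit copy protects `a`
```

Both calls produce the same masked values; the difference is only whether a copy of the input has to be made first.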
@lithomas1 do you think you could resolve the merge conflicts? @ev-br also updated array-api-compat in gh-21796
Apologies for the long silence here. I did some digging into the remaining fft/special failures and other than some small issues with extra warnings coming from dask, I think I have those two modules passing locally. One source of failures that I've noticed is due to us passing the "naked" dask.array namespace (as opposed to the array-api-compat wrapped version) in as the The problem is worse for dask since oftentimes in the there is an This is probably why we had the issue with the copy not happening in
import dask.array  # type: ignore[import-not-found]
xp_available_backends.update({'dask.array': dask.array})
- import dask.array  # type: ignore[import-not-found]
- xp_available_backends.update({'dask.array': dask.array})
+ import array_api_compat.dask.array as da  # type: ignore[import-not-found]
+ xp_available_backends.update({'dask.array': da})
This is the change that I'm talking about that fixes a lot of things.
Assuming this change looks good, I'll clean up the previous, now unneeded fixes locally and send in another PR to add this + fix merge conflicts.
We don't want to do this, actually. We want to pass arrays from the unwrapped namespace in tests to check that we correctly coerce with array_namespace. If you need to use standard functions in tests, you can write xp_test = array_namespace(x). It's good to add a comment about what's missing from the unwrapped namespace.
Got it, I'll try to add individual array_namespace calls and report back on what the diff looks like after.
If you grep for xp_test you'll see examples of the pattern. There is another pattern like acos = array_namespace(x).acos - either is fine.
if is_dask(xp):
    x.compute()
    y.compute()
Think these should be assigned to variables (possibly the same variables they already occupy, like x = x.compute()).
Also the id check at the end makes me wonder if this should use .persist() instead of .compute() (the previous point would still apply).
Now, I should have all of fft/linalg/special passing. stats looks do-able, the main changes seem to be silencing/ignoring the RuntimeWarnings from dask and figuring out how to make I've fixed up the merge conflicts in lucascolley#15, however the diff is very large, so I've split that out of my dask changes. (Maybe this is because I do merge commits instead of rebasing to keep up with main?)
thanks @lithomas1, I rebased on main here. The CI log is quite messy due to an error in
Thanks for the update. I've attempted to cherry-pick my changes on top of your branch in lucascolley#16. (The history is still a little messed up - but I think it should be OK if you're fine with squashing my PR into a single commit).
[skip cirrus] [skip circle]
continued in #22240
Reference issue
Towards gh-18867
What does this implement/fix?
Test with dask.array via array-api-compat.
Problems so far
Dask:
asarray for array input with dtype dask/dask#11288
ndim>1 dask/dask#11398
array-api-compat:
dask.array.clip fails with missing attribute broadcast_shapes data-apis/array-api-compat#176
sort or argsort
SciPy:
dask.array support #20956 (comment)