da.compress fails when the condition array is a CuPy array.
In [1]: import cupy as cp
In [2]: import dask.array as da
In [3]: rs = da.random.RandomState(RandomState=cp.random.RandomState)
In [4]: x = rs.randint(0, 3, size=(10, 10),
...: chunks=(20, 5), dtype="uint8")
In [5]: da.compress(cp.asarray([True]), x, axis=0).compute()
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
<ipython-input-5-89fea14f2e31> in <module>
----> 1 da.compress(cp.asarray([True]), x, axis=0).compute()
/datasets/bzaitlen/miniconda3/envs/20200501/lib/python3.7/site-packages/dask/array/routines.py in compress(condition, a, axis)
1106
1107 # Use `condition` to select along 1 dimension
-> 1108 a = a[tuple(condition if i == axis else slice(None) for i in range(a.ndim))]
1109
1110 return a
/datasets/bzaitlen/miniconda3/envs/20200501/lib/python3.7/site-packages/dask/array/core.py in __getitem__(self, index)
1532 )
1533
-> 1534 index2 = normalize_index(index, self.shape)
1535
1536 dependencies = {self.name}
/datasets/bzaitlen/miniconda3/envs/20200501/lib/python3.7/site-packages/dask/array/slicing.py in normalize_index(idx, shape)
813 for i, d in zip(idx, none_shape):
814 if d is not None:
--> 815 check_index(i, d)
816 idx = tuple(map(sanitize_index, idx))
817 idx = tuple(map(normalize_slice, idx, none_shape))
/datasets/bzaitlen/miniconda3/envs/20200501/lib/python3.7/site-packages/dask/array/slicing.py in check_index(ind, dimension)
882 elif ind >= dimension:
883 raise IndexError(
--> 884 "Index is not smaller than dimension %d >= %d" % (ind, dimension)
885 )
886
IndexError: Index is not smaller than dimension 1 >= 1
However, da.compress succeeds if the condition is a Python list or a NumPy array:
In [6]: da.compress([True], x, axis=0).compute()
Out[6]: array([[2, 0, 2, 2, 2, 0, 1, 0, 1, 2]], dtype=uint8)
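In the traceback, the CuPy mask is not caught by the list/np.ndarray branch of check_index, so it falls through to the scalar comparison `ind >= dimension`, which is meant for integer indices. That comparison is elementwise, so it returns a one-element boolean array whose truth value is True, and `%d` formats that array as 1; hence `1 >= 1`. (The dimension appears to be 1 rather than 10 because compress first trims the axis to len(condition).) The pure-Python sketch below models only that branching. `FakeDeviceArray` and `check_index_sketch` are hypothetical names, the `list` check stands in for `isinstance(ind, (list, np.ndarray))`, and the real check_index has more branches than shown.

```python
class FakeDeviceArray:
    """Hypothetical stand-in for a 1-D CuPy array: array-like, but neither
    a list nor an np.ndarray, so an isinstance check does not catch it."""

    def __init__(self, values):
        self.values = list(values)
        self.shape = (len(self.values),)

    def __ge__(self, other):
        # Elementwise comparison, as NumPy and CuPy arrays do.
        return FakeDeviceArray(v >= other for v in self.values)

    def __bool__(self):
        # Like NumPy/CuPy: only a one-element array has a truth value.
        if len(self.values) != 1:
            raise ValueError("truth value of a multi-element array is ambiguous")
        return bool(self.values[0])

    def __int__(self):
        # Used by "%d" formatting; also only defined for one element.
        (value,) = self.values
        return int(value)


def check_index_sketch(ind, dimension):
    """Simplified model of the check_index branches seen in the traceback."""
    if isinstance(ind, list):  # stands in for isinstance(ind, (list, np.ndarray))
        if all(isinstance(v, bool) for v in ind) and len(ind) != dimension:
            raise IndexError("boolean index did not match indexed array")
        return
    elif isinstance(ind, slice) or ind is None:
        return
    elif ind >= dimension:  # intended for integer indices only
        raise IndexError(
            "Index is not smaller than dimension %d >= %d" % (ind, dimension)
        )


check_index_sketch([True], 1)  # list mask: validated by length, no error
try:
    check_index_sketch(FakeDeviceArray([True]), 1)
except IndexError as err:
    print(err)  # Index is not smaller than dimension 1 >= 1
```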
The fix is probably a small change to the array-like handling in check_index, whose array branch currently matches only lists and NumPy arrays:
Line 865 in 666b53a:
elif isinstance(ind, (list, np.ndarray)):
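One possible direction, sketched under assumptions: let the array branch accept anything that looks like an array (a tuple `shape` plus a `dtype`) instead of only lists and np.ndarray, and validate a boolean mask from its shape and dtype alone, so a device array never has to be copied to the host. `is_arraylike_sketch`, `check_boolean_mask` and `DeviceMask` are hypothetical names used for illustration; dask already ships a similar duck-typing helper, `dask.utils.is_arraylike`, which a real patch might reuse. This is an illustration only, not dask's code.

```python
def is_arraylike_sketch(x):
    """Hypothetical duck-typing test: a tuple .shape and a .dtype.

    NumPy and CuPy arrays both pass; lists, ints and slices do not.
    """
    return isinstance(getattr(x, "shape", None), tuple) and hasattr(x, "dtype")


def check_boolean_mask(ind, dimension):
    """Validate a 1-D boolean mask using metadata only.

    Returns True for a valid mask, False if `ind` is not a boolean mask (the
    caller would then fall through to the integer-index checks), and raises
    IndexError if the mask length does not match `dimension`.
    """
    if isinstance(ind, list):
        if not ind or not all(isinstance(v, bool) for v in ind):
            return False
        length = len(ind)
    elif is_arraylike_sketch(ind) and len(ind.shape) == 1:
        # str(dtype) is "bool" for NumPy/CuPy boolean arrays; no data is read.
        if str(ind.dtype) != "bool":
            return False
        length = ind.shape[0]
    else:
        return False
    if length != dimension:
        raise IndexError(
            "boolean index did not match indexed array; dimension is %d "
            "but corresponding boolean dimension is %d" % (dimension, length)
        )
    return True


class DeviceMask:
    """Hypothetical stand-in exposing only the metadata of a CuPy bool array."""

    def __init__(self, length):
        self.shape = (length,)
        self.dtype = "bool"


print(check_boolean_mask(DeviceMask(1), 1))  # True
print(check_boolean_mask([True], 1))         # True
print(check_boolean_mask(3, 10))             # False: not a mask
```

Because the check reads only `shape` and `dtype`, it behaves the same for NumPy and CuPy masks. Integer index arrays would still need a bounds check such as `(ind >= dimension).any()`, which for a CuPy array means a small device-to-host transfer to read the result.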