
da.compress fails when using cupy for condition array  #6169

@quasiben


da.compress fails when the condition array is a CuPy array.

In [1]: import cupy as cp
In [2]: import dask.array as da

In [3]: rs = da.random.RandomState(RandomState=cp.random.RandomState)

In [4]: x = rs.randint(0, 3, size=(10, 10),
   ...:                chunks=(20, 5), dtype="uint8")

In [5]: da.compress(cp.asarray([True]), x, axis=0).compute()
---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-5-89fea14f2e31> in <module>
----> 1 da.compress(cp.asarray([True]), x, axis=0).compute()

/datasets/bzaitlen/miniconda3/envs/20200501/lib/python3.7/site-packages/dask/array/routines.py in compress(condition, a, axis)
   1106
   1107     # Use `condition` to select along 1 dimension
-> 1108     a = a[tuple(condition if i == axis else slice(None) for i in range(a.ndim))]
   1109
   1110     return a

/datasets/bzaitlen/miniconda3/envs/20200501/lib/python3.7/site-packages/dask/array/core.py in __getitem__(self, index)
   1532         )
   1533
-> 1534         index2 = normalize_index(index, self.shape)
   1535
   1536         dependencies = {self.name}

/datasets/bzaitlen/miniconda3/envs/20200501/lib/python3.7/site-packages/dask/array/slicing.py in normalize_index(idx, shape)
    813     for i, d in zip(idx, none_shape):
    814         if d is not None:
--> 815             check_index(i, d)
    816     idx = tuple(map(sanitize_index, idx))
    817     idx = tuple(map(normalize_slice, idx, none_shape))

/datasets/bzaitlen/miniconda3/envs/20200501/lib/python3.7/site-packages/dask/array/slicing.py in check_index(ind, dimension)
    882     elif ind >= dimension:
    883         raise IndexError(
--> 884             "Index is not smaller than dimension %d >= %d" % (ind, dimension)
    885         )
    886

IndexError: Index is not smaller than dimension 1 >= 1

However, da.compress succeeds if the condition array is a list or NumPy array:

In [6]: da.compress([True], x, axis=0).compute()
Out[6]: array([[2, 0, 2, 2, 2, 0, 1, 0, 1, 2]], dtype=uint8)

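The fix is probably a small change to the array-like handling in check_index. A CuPy array is neither a list nor an np.ndarray, so it falls through to the scalar comparison ind >= dimension (line 882 in the traceback), where the mask [True] is compared as 1 >= 1. The branch that currently matches array-likes is: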

elif isinstance(ind, (list, np.ndarray)):
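A minimal sketch of one possible shape for that change, assuming a duck-typed array test (anything with a tuple shape and a dtype) is an acceptable replacement for the isinstance check. This is not dask's code or its actual fix: check_index_sketch, _is_arraylike and FakeDeviceArray are hypothetical names. The key point is that the array is inspected through its own dtype, size and comparison methods, because CuPy refuses the implicit NumPy conversion that np.asanyarray would attempt. FakeDeviceArray is a NumPy-backed stand-in that mimics that refusal so the sketch runs without a GPU.

```python
import numpy as np


def _is_arraylike(x):
    # Hypothetical duck-typed test: a tuple ``shape`` plus a ``dtype`` is
    # enough to count as an array, so NumPy and CuPy arrays both qualify.
    return isinstance(getattr(x, "shape", None), tuple) and hasattr(x, "dtype")


def check_index_sketch(ind, dimension):
    """Hypothetical variant of check_index that never converts ``ind`` to NumPy."""
    if np.isnan(dimension):
        return  # unknown dimension, assumed to be in bounds
    if isinstance(ind, list):
        ind = np.asarray(ind)  # lists live on the host, so conversion is safe
    if _is_arraylike(ind):
        # Use the array's own dtype/size/comparisons (which CuPy supports)
        # instead of np.asanyarray, which CuPy rejects.
        if ind.dtype == bool:
            # A boolean mask must match the dimension length exactly
            if ind.size != dimension:
                raise IndexError(
                    "Boolean index length %d does not match dimension %d"
                    % (ind.size, dimension)
                )
        elif bool((ind >= dimension).any()) or bool((ind < -dimension).any()):
            raise IndexError("Index out of bounds for dimension %d" % dimension)
        return
    if isinstance(ind, slice) or ind is None:
        return
    # Plain scalar index, as in the original function
    if ind >= dimension:
        raise IndexError(
            "Index is not smaller than dimension %d >= %d" % (ind, dimension)
        )


class FakeDeviceArray:
    # Hypothetical stand-in for a CuPy array: wraps NumPy data but, like
    # CuPy, raises TypeError on implicit conversion via __array__.
    def __init__(self, data):
        self._data = np.asarray(data)
        self.shape = self._data.shape
        self.dtype = self._data.dtype
        self.size = self._data.size

    def __array__(self, *args, **kwargs):
        raise TypeError("Implicit conversion to a NumPy array is not allowed")

    def __ge__(self, other):
        return FakeDeviceArray(self._data >= other)

    def __lt__(self, other):
        return FakeDeviceArray(self._data < other)

    def any(self):
        return bool(self._data.any())


# The mask from the report, against the length-1 dimension in the traceback:
check_index_sketch(FakeDeviceArray([True]), 1)  # passes without IndexError
```

The traceback also shows normalize_index passing every index through sanitize_index right after check_index; whether that step accepts CuPy arrays is not covered by this sketch.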
