
Inefficient outer indexing on sparse matrices when using indices of shape (N, 1) #11255

@jonaslb

Description


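A common way to perform outer indexing on a sparse matrix can be unexpectedly slow. This is the approach of passing a row index of shape (N, 1) together with a column index of shape (1, M) so that the two broadcast to an (N, M) result. In the IPython session below it is roughly 200 times slower than equivalent ways of extracting the same submatrix.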

I think this could be fixed by adding another fast path to scipy.sparse._index.IndexMixin for inputs like those in In [8], where both index arrays are 2D. It could work the same way as the existing fast path used by In [10], where the second index is 1D.

Reproducing code example:

In [1]: import scipy.sparse as ss, numpy as np

In [2]: matrix = ss.random(2000, 50000, format="csr")

In [3]: idx0 = np.arange(1000)[:, None]

In [4]: idx1 = np.arange(1000, 2000)

In [5]: imod = np.arange(25)[None, :]

In [6]: idx1 = (idx1[:, None] + 1000 * imod).T.reshape(-1, 1)

In [8]: # The 'easy' way to index: also really slow
   ...: %timeit out0 = matrix[idx0, idx1.T]
743 ms ± 48 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [9]: # Equivalent submatrix, but not so easy to write: Much faster
   ...: %timeit out1 = matrix[idx0.ravel(), :][:, idx1.ravel()]
3.57 ms ± 89.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

In [10]: # There is an existing fast path when the second index is 1D
    ...: %timeit out2 = matrix[idx0, idx1.ravel()]
3.47 ms ± 113 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
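To confirm that the three expressions return the same submatrix, the session can be rerun at a smaller scale and the outputs compared densely. The sizes, `density` and `random_state` below are arbitrary choices, not taken from the report.

```python
import numpy as np
import scipy.sparse as ss

# Scaled-down version of the setup above (sizes chosen arbitrarily).
matrix = ss.random(200, 5000, density=0.01, format="csr", random_state=0)
idx0 = np.arange(100)[:, None]                          # shape (100, 1)
idx1 = np.arange(100, 200)
imod = np.arange(25)[None, :]                           # shape (1, 25)
idx1 = (idx1[:, None] + 100 * imod).T.reshape(-1, 1)    # shape (2500, 1)

out0 = matrix[idx0, idx1.T]                       # (100, 1) x (1, 2500): outer indexing
out1 = matrix[idx0.ravel(), :][:, idx1.ravel()]   # rows first, then columns
out2 = matrix[idx0, idx1.ravel()]                 # (100, 1) x 1-D, also broadcasts

# Dense comparison is cheap at this size.
same = all(np.array_equal(out0.toarray(), o.toarray()) for o in (out1, out2))
```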

Scipy/Numpy/Python version information:

In [40]: import sys, scipy, numpy; print(scipy.__version__, numpy.__version__, sys.version_info)
1.3.0 1.16.4 sys.version_info(major=3, minor=7, micro=3, releaselevel='final', serial=0)
