
I wish to speed up my machine learning algorithm (written in Python) using Numba (http://numba.pydata.org/). The algorithm takes a sparse matrix as its input data. In my pure Python implementation I used csr_matrix and related classes from SciPy, but apparently these are not compatible with Numba's JIT compiler.

I have also created my own custom class to implement the sparse matrix (basically a list of lists of (index, value) pairs), but it is also incompatible with Numba (i.e., I got an error message saying it doesn't recognize the extension type).

Is there an alternative, simple way to implement sparse matrix using only numpy (without resorting to SciPy) that is compatible with Numba? Any example code would be appreciated. Thanks!

What features of csr_matrix did you use? You could try to reproduce their behavior in numpy, although I seriously doubt it would result in a speed-up in general. - Jaime
I only use csr_matrix to store my data. All I need is to iterate over it row by row and, for every row, retrieve the list of indices and values. That is why, for now, I created my own class, implemented as a simple list of lists. But again, it is not recognized by Numba's compiler. - rjo2909
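A minimal sketch of a NumPy-only replacement for the custom class, assuming it is a plain list of rows where each row is a list of (column, value) pairs as described above. The helper name `pairs_to_csr` is hypothetical, and storing values as float64 is an assumption. It packs everything into the three flat arrays (data, indices, indptr) that make up the CSR layout, and Numba accepts plain arrays like these.

```python
import numpy as np

def pairs_to_csr(rows):
    """Hypothetical helper: convert a list of rows, each a list of
    (column, value) pairs, into CSR arrays (data, indices, indptr)."""
    n_rows = len(rows)
    # indptr[r] .. indptr[r + 1] delimits row r's entries in data/indices
    lengths = np.array([len(row) for row in rows], dtype=np.int64)
    indptr = np.zeros(n_rows + 1, dtype=np.int64)
    np.cumsum(lengths, out=indptr[1:])
    nnz = int(indptr[-1])
    indices = np.empty(nnz, dtype=np.int64)
    data = np.empty(nnz, dtype=np.float64)  # assumption: float values
    pos = 0
    for row in rows:
        for col, val in row:
            indices[pos] = col
            data[pos] = val
            pos += 1
    return data, indices, indptr

# The matrix [[1, 2, 0], [0, 0, 3], [4, 0, 5]] as a list of (index, value) pairs
rows = [[(0, 1.0), (1, 2.0)], [(2, 3.0)], [(0, 4.0), (2, 5.0)]]
data, indices, indptr = pairs_to_csr(rows)
# data -> [1.0, 2.0, 3.0, 4.0, 5.0], indices -> [0, 1, 2, 0, 2], indptr -> [0, 2, 3, 5]
```

Row `r` is then `indices[indptr[r]:indptr[r + 1]]` and `data[indptr[r]:indptr[r + 1]]`, which is exactly the "iterate row by row" access pattern described in the comment.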

2 Answers


If all you have to do is iterate over the values of a CSR matrix, you can pass its attributes data, indptr, and indices (plain NumPy arrays) to a jitted function instead of the CSR matrix object.

from scipy import sparse
from numba import njit

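@njit
def print_csr(A, iA, jA):
    # A = values (data), iA = row pointers (indptr), jA = column indices (indices)
    for row in range(len(iA)-1):
        # this row's entries are stored at positions iA[row] .. iA[row+1]-1
        for i in range(iA[row], iA[row+1]):
            print(row, jA[i], A[i])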

A = sparse.csr_matrix([[1, 2, 0], [0, 0, 3], [4, 0, 5]])
print_csr(A.data, A.indptr, A.indices)
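The same pattern extends beyond printing. Below is a minimal sketch of a sparse matrix-vector product on the three arrays. `csr_matvec` is a hypothetical name, and the try/except fallback is only there so the snippet also runs where Numba is not installed. The arrays are written out by hand for the matrix above (as floats), which also shows that SciPy is not needed once the data is in this form.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # fall back to plain Python if Numba is missing
    def njit(func):
        return func

@njit
def csr_matvec(data, indices, indptr, x):
    """Compute y = A @ x for a CSR matrix given as (data, indices, indptr)."""
    n_rows = indptr.shape[0] - 1
    y = np.zeros(n_rows)  # float64 result
    for row in range(n_rows):
        acc = 0.0
        # entries of this row live in data[indptr[row]:indptr[row + 1]]
        for k in range(indptr[row], indptr[row + 1]):
            acc += data[k] * x[indices[k]]
        y[row] = acc
    return y

# CSR arrays for [[1, 2, 0], [0, 0, 3], [4, 0, 5]], built without SciPy
data = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
indices = np.array([0, 1, 2, 0, 2])
indptr = np.array([0, 2, 3, 5])
y = csr_matvec(data, indices, indptr, np.array([1.0, 2.0, 3.0]))
# y holds [5.0, 9.0, 19.0]
```

With a SciPy matrix at hand, the call would be `csr_matvec(A.data, A.indices, A.indptr, x)`; the argument order here differs from `print_csr` above, so keep the names straight.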

You can access the data of your sparse matrix as pure NumPy arrays or Python lists. For example:

import numpy as np
from scipy import sparse

M = sparse.csr_matrix([[1,0,0],[1,0,1],[1,1,1]])
ML = M.tolil()

for d, r in zip(ML.data, ML.rows):
    # d, r are lists: the values and the column indices of one row
    dr = np.array([d, r])
    print(dr)

produces:

[[1]
 [0]]
[[1 1]
 [0 2]]
[[1 1 1]
 [0 1 2]]

Numba should be able to handle code that uses these arrays, provided, of course, that it does not expect every row's array to have the same length.


The lil format stores its values in 2 object-dtype arrays, data and rows, which hold the values and the column indices as Python lists, one list per row.
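Numba's nopython mode does not support object-dtype arrays, so `ML.data` and `ML.rows` themselves cannot be passed to a jitted function. One workaround is to flatten the per-row lists into contiguous CSR-style arrays first. This is a minimal sketch: `lil_lists_to_csr` is a hypothetical helper, and converting values to float64 is an assumption.

```python
import numpy as np

def lil_lists_to_csr(row_values, row_columns):
    """Hypothetical helper: flatten per-row value/column lists
    (e.g. a lil matrix's .data and .rows) into CSR arrays."""
    lengths = np.array([len(cols) for cols in row_columns], dtype=np.int64)
    # indptr[r] .. indptr[r + 1] delimits row r in the flat arrays
    indptr = np.concatenate(([0], np.cumsum(lengths))).astype(np.int64)
    nnz = int(indptr[-1])
    indices = np.fromiter((c for cols in row_columns for c in cols),
                          dtype=np.int64, count=nnz)
    data = np.fromiter((v for vals in row_values for v in vals),
                       dtype=np.float64, count=nnz)  # assumption: float values
    return data, indices, indptr

# The per-row lists that ML.data and ML.rows hold for M above
data, indices, indptr = lil_lists_to_csr([[1], [1, 1], [1, 1, 1]],
                                         [[0], [0, 2], [0, 1, 2]])
```

The resulting arrays have fixed dtypes and no ragged rows, so the varying row lengths are handled through `indptr` rather than through differently sized per-row arrays.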