Skip to content

Commit

Permalink
Fix warning in .inv() of LinearOperators of type Matrix wrappin…
Browse files Browse the repository at this point in the history
…g `scipy.sparse.spmatrix` (#434)

* Convert sparse matrix to csc format before inversion

* Add test case for sparse linear operator

* Return correct exception for singular sparse linops

* I am dumb...

* Fix warnings in test code
  • Loading branch information
marvinpfoertner committed Jun 21, 2021
1 parent 7d3cf99 commit 2044c89
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 1 deletion.
8 changes: 7 additions & 1 deletion src/probnum/linops/_linear_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -904,7 +904,7 @@ def __init__(
matmul = LinearOperator.broadcast_matmat(lambda x: self.A @ x)
rmatmul = LinearOperator.broadcast_rmatmat(lambda x: x @ self.A)
todense = self.A.toarray
inverse = lambda: Matrix(scipy.sparse.linalg.inv(self.A))
inverse = self._sparse_inv
trace = lambda: self.A.diagonal().sum()
else:
self.A = np.asarray(A)
Expand Down Expand Up @@ -953,6 +953,12 @@ def _astype(

return Matrix(A_astype)

def _sparse_inv(self) -> "Matrix":
try:
return Matrix(scipy.sparse.linalg.inv(self.A.tocsc()))
except RuntimeError as err:
raise np.linalg.LinAlgError(str(err)) from err


class Identity(LinearOperator):
"""The identity operator.
Expand Down
25 changes: 25 additions & 0 deletions tests/test_linops/test_linops_cases/linear_operator_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import numpy as np
import pytest
import scipy.sparse

import probnum as pn

Expand Down Expand Up @@ -33,3 +34,27 @@ def case_matrix(matrix: np.ndarray) -> Tuple[pn.linops.LinearOperator, np.ndarra
@pytest.mark.parametrize("n", [3, 4, 8, 12, 15])
def case_identity(n: int) -> Tuple[pn.linops.LinearOperator, np.ndarray]:
return pn.linops.Identity(shape=n), np.eye(n)


@pytest.mark.parametrize("rng", [np.random.default_rng(42)])
def case_sparse_matrix(
rng: np.random.Generator,
) -> Tuple[pn.linops.LinearOperator, np.ndarray]:
matrix = scipy.sparse.rand(
10, 10, density=0.1, format="coo", dtype=np.double, random_state=rng
)
matrix.setdiag(2)
matrix = matrix.tocsr()

return pn.linops.Matrix(matrix), matrix.toarray()


@pytest.mark.parametrize("rng", [np.random.default_rng(42)])
def case_sparse_matrix_singular(
rng: np.random.Generator,
) -> Tuple[pn.linops.LinearOperator, np.ndarray]:
matrix = scipy.sparse.rand(
10, 10, density=0.01, format="csr", dtype=np.double, random_state=rng
)

return pn.linops.Matrix(matrix), matrix.toarray()

0 comments on commit 2044c89

Please sign in to comment.