Skip to content

Commit

Permalink
Coverage for basic tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jcapriot committed Oct 10, 2024
1 parent 9350dba commit 9305f05
Show file tree
Hide file tree
Showing 3 changed files with 186 additions and 35 deletions.
36 changes: 26 additions & 10 deletions pymatsolver/solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,11 @@ def __init__(
self._dtype = np.dtype(A.dtype)

if accuracy_tol is not None:
warnings.warn("accuracy_tol is deprecated and will be removed in v0.4.0, use check_rtol and check_atol.", FutureWarning, stacklevel=2)
warnings.warn(
"accuracy_tol is deprecated and will be removed in v0.4.0, use check_rtol and check_atol.",
FutureWarning,
stacklevel=3
)
check_rtol = accuracy_tol

self.check_accuracy = check_accuracy
Expand Down Expand Up @@ -249,7 +253,7 @@ def transpose(self):
if self._transpose_class is None:
raise NotImplementedError(
'The transpose for the {} class is not possible.'.format(
self.__name__
self.__class__.__name__
)
)
newS = self._transpose_class(self.A.T, **self.get_attributes())
Expand Down Expand Up @@ -307,7 +311,8 @@ def solve(self, rhs):
"In Future pymatsolver v0.4.0, passing a vector of shape (n, 1) to the solve method "
"will return an array with shape (n, 1), instead of always returning a flattened array. "
"This is to be consistent with numpy.linalg.solve broadcasting.",
FutureWarning
FutureWarning,
stacklevel=2
)
if rhs.shape[-2] != n:
raise ValueError(f'Second to last dimension should be {n}, got {rhs.shape}')
Expand Down Expand Up @@ -416,15 +421,15 @@ def __init__(self, A, check_accuracy=False, check_rtol=1e-6, check_atol=0, accur
try:
self._diagonal = np.asarray(A.diagonal())
if not np.all(self._diagonal):
# this works because 0 evaluates as False!
# this works because 0.0 evaluates as False!
raise ValueError("Diagonal matrix has a zero along the diagonal.")
except AttributeError:
raise TypeError("A must have a diagonal() method.")
kwargs.pop("is_symmetric", None)
kwargs.pop("is_hermitian", None)
is_hermitian = kwargs.pop("is_hermitian", None)
is_positive_definite = kwargs.pop("is_positive_definite", None)
super().__init__(
A, is_symmetric=True, is_hermitian=True, check_accuracy=check_accuracy, check_rtol=check_rtol, check_atol=check_atol, accuracy_tol=accuracy_tol, **kwargs
A, is_symmetric=True, is_hermitian=False, check_accuracy=check_accuracy, check_rtol=check_rtol, check_atol=check_atol, accuracy_tol=accuracy_tol, **kwargs
)
if is_positive_definite is None:
if self.is_real:
Expand All @@ -434,6 +439,14 @@ def __init__(self, A, check_accuracy=False, check_rtol=1e-6, check_atol=0, accur
is_positive_definite = bool(is_positive_definite)
self.is_positive_definite = is_positive_definite

if is_hermitian is None:
if self.is_real:
is_hermitian = True
else:
# can only be hermitian if all imaginary components on diagonal are zero.
is_hermitian = not np.any(self._diagonal.imag)
self.is_hermitian = is_hermitian

def _solve_single(self, rhs):
return rhs / self._diagonal

Expand Down Expand Up @@ -466,12 +479,15 @@ class TriangularSolver(Base):
"""

def __init__(self, A, lower=True, check_accuracy=False, check_rtol=1e-6, check_atol=0, accuracy_tol=None, **kwargs):
kwargs.pop("is_hermitian", False)
kwargs.pop("is_symmetric", False)
if not (sp.issparse(A) and A.format in ['csr','csc']):
# pop off unneeded keyword arguments.
is_hermitian = kwargs.pop("is_hermitian", False)
is_symmetric = kwargs.pop("is_symmetric", False)
is_positive_definite = kwargs.pop("is_positive_definite", False)
if not (sp.issparse(A) and A.format in ['csr', 'csc']):
A = sp.csc_matrix(A)
A.sum_duplicates()
super().__init__(A, is_hermitian=False, is_symmetric=False, check_accuracy=check_accuracy, check_rtol=check_rtol, check_atol=check_atol, accuracy_tol=accuracy_tol, **kwargs)
super().__init__(A, is_hermitian=is_hermitian, is_symmetric=is_symmetric, is_positive_definite=is_positive_definite, check_accuracy=check_accuracy, check_rtol=check_rtol, check_atol=check_atol, accuracy_tol=accuracy_tol, **kwargs)

self.lower = lower

@property
Expand Down
169 changes: 144 additions & 25 deletions tests/test_Basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,6 @@

TOL = 1e-12

def test_DiagonalSolver():

A = sp.identity(5)*2.0
rhs = np.c_[np.arange(1, 6), np.arange(2, 11, 2)]
X = Diagonal(A) * rhs
x = Diagonal(A) * rhs[:, 0]

sol = rhs/2.0

with pytest.raises(TypeError):
Diagonal(A, check_accuracy=np.array([1, 2, 3]))
with pytest.raises(ValueError):
Diagonal(A, accuracy_tol=0)

npt.assert_allclose(sol, X, atol=TOL)
npt.assert_allclose(sol[:, 0], x, atol=TOL)

class IdentitySolver(pymatsolver.solvers.Base):
""""A concrete implementation of Base, for testing purposes"""
def _solve_single(self, rhs):
Expand All @@ -32,34 +15,58 @@ def _solve_single(self, rhs):
def _solve_multiple(self, rhs):
return rhs

class NotTransposableIdentitySolver(IdentitySolver):
""" A class that can't be transposed."""

@property
def _transpose_class(self):
return None


def test_basics():

Ainv = IdentitySolver(np.eye(4))
assert Ainv.is_symmetric == True
assert Ainv.is_hermitian == True
assert Ainv.is_symmetric
assert Ainv.is_hermitian
assert Ainv.shape == (4, 4)
assert Ainv.is_real

Ainv = IdentitySolver(np.eye(4) + 0j)
assert Ainv.is_symmetric == True
assert Ainv.is_hermitian == True
assert Ainv.is_symmetric
assert Ainv.is_hermitian
assert not Ainv.is_real

Ainv = IdentitySolver(sp.eye(4))
assert Ainv.is_symmetric
assert Ainv.is_hermitian
assert Ainv.shape == (4, 4)

Ainv = IdentitySolver(sp.eye(4).astype(np.complex128))
assert Ainv.is_symmetric
assert Ainv.is_hermitian
assert Ainv.shape == (4, 4)
assert not Ainv.is_real

def test_basic_solve():
Ainv = IdentitySolver(np.eye(4))

rhs = np.arange(4)
rhs2d = np.arange(8).reshape(4, 2)
rhs3d = np.arange(16).reshape(2, 4, 2)
rhs3d = np.arange(24).reshape(3, 4, 2)

npt.assert_equal(Ainv @ rhs, rhs)
npt.assert_equal(Ainv @ rhs2d, rhs2d)
npt.assert_equal(Ainv @ rhs3d, rhs3d)

npt.assert_equal(rhs @ Ainv, rhs)
npt.assert_equal(rhs.T * Ainv, rhs)
npt.assert_equal(rhs * Ainv, rhs)

npt.assert_equal(rhs2d.T @ Ainv, rhs2d.T)
npt.assert_equal(rhs2d.T * Ainv, rhs2d.T)

npt.assert_equal(rhs3d.swapaxes(-1, -2) @ Ainv, rhs3d.swapaxes(-1, -2))
npt.assert_equal(rhs3d.swapaxes(-1, -2) * Ainv, rhs3d.swapaxes(-1, -2))

# use Diagonal solver as a concrete instance of the Base to test for some errors

def test_errors_and_warnings():

Expand Down Expand Up @@ -101,4 +108,116 @@ def test_errors_and_warnings():

with pytest.warns(FutureWarning, match="In Future pymatsolver v0.4.0, passing a vector.*"):
Ainv = IdentitySolver(np.eye(4, 4))
Ainv @ np.ones((4, 1))
Ainv @ np.ones((4, 1))

with pytest.raises(NotImplementedError, match="The transpose for the.*"):
Ainv = NotTransposableIdentitySolver(np.eye(4, 4), is_symmetric=False)
Ainv.T



def test_DiagonalSolver():

A = sp.identity(5)*2.0
rhs = np.c_[np.arange(1, 6), np.arange(2, 11, 2)]
X = Diagonal(A) * rhs
x = Diagonal(A) * rhs[:, 0]

sol = rhs/2.0

with pytest.raises(TypeError):
Diagonal(A, check_accuracy=np.array([1, 2, 3]))
with pytest.raises(ValueError):
Diagonal(A, check_rtol=0)

npt.assert_allclose(sol, X, atol=TOL)
npt.assert_allclose(sol[:, 0], x, atol=TOL)

def test_diagonal_errors():

with pytest.raises(TypeError, match="A must have a diagonal.*"):
Diagonal(
[
[2, 0],
[0, 1]
]
)

with pytest.raises(ValueError, match="Diagonal matrix has a zero along the diagonal."):
Diagonal(
np.array(
[
[0, 0],
[0, 1]
]
)
)

def test_diagonal_inferance():

Ainv = Diagonal(
np.array(
[
[2., 0.],
[0., 1.],
]
),
)

assert Ainv.is_symmetric
assert Ainv.is_positive_definite
assert Ainv.is_hermitian
assert Ainv.is_real

Ainv = Diagonal(
np.array(
[
[2.0, 0],
[0, -1.0],
]
),
)

assert Ainv.is_symmetric
assert not Ainv.is_positive_definite
assert Ainv.is_hermitian
assert Ainv.is_real

Ainv = Diagonal(
np.array(
[
[2 + 0j, 0],
[0, 2 + 0j],
]
)
)
assert not Ainv.is_real
assert Ainv.is_symmetric
assert Ainv.is_hermitian
assert Ainv.is_positive_definite

Ainv = Diagonal(
np.array(
[
[2 + 0j, 0],
[0, -2 + 0j],
]
)
)
assert not Ainv.is_real
assert Ainv.is_symmetric
assert Ainv.is_hermitian
assert not Ainv.is_positive_definite

Ainv = Diagonal(
np.array(
[
[2 + 1j, 0],
[0, 2 + 0j],
]
)
)
assert not Ainv.is_real
assert Ainv.is_symmetric
assert not Ainv.is_hermitian
assert not Ainv.is_positive_definite
16 changes: 16 additions & 0 deletions tests/test_Triangle.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,19 @@ def test_solve(solver):
Ainv = solver(A)
npt.assert_allclose(Ainv * rhs, sol, atol=TOL)
npt.assert_allclose(Ainv * rhs[:, 0], sol[:, 0], atol=TOL)


def test_triangle_errors():
A = sp.eye(5, format='csc')

with pytest.raises(TypeError, match="lower must be a bool."):
Ainv = pymatsolver.Forward(A)
Ainv.lower = 1


def test_mat_convert():
Ainv = pymatsolver.Forward(sp.eye(5, format='coo'))
x = np.arange(5)
npt.assert_allclose(Ainv @ x, x)


0 comments on commit 9305f05

Please sign in to comment.