From f9f96ce472059192f58893b8d7add35b41d1515c Mon Sep 17 00:00:00 2001 From: Joseph Capriotti Date: Thu, 26 Sep 2024 14:38:32 -0600 Subject: [PATCH] fix pardiso tests --- tests/test_Pardiso.py | 41 ++++++++++++++++++----------------------- 1 file changed, 18 insertions(+), 23 deletions(-) diff --git a/tests/test_Pardiso.py b/tests/test_Pardiso.py index 9b2f8da..cd6bfff 100644 --- a/tests/test_Pardiso.py +++ b/tests/test_Pardiso.py @@ -1,11 +1,5 @@ +import pydiso.mkl_solver import pymatsolver -try: - from pydiso.mkl_solver import ( - get_mkl_pardiso_max_threads, - PardisoTypeConversionWarning - ) -except ImportError: - Pardiso = None import numpy as np import pytest import scipy.sparse as sp @@ -14,7 +8,7 @@ if not pymatsolver.AvailableSolvers['Pardiso']: pytest.skip(reason="Pardiso solver is not installed", allow_module_level=True) else: - from pymatsolver.direct.pardiso import get_mkl_pardiso_max_threads + from pydiso.mkl_solver import PardisoTypeConversionWarning, get_mkl_pardiso_max_threads TOL = 1e-10 @@ -30,24 +24,25 @@ def test_mat_data(): rhs = A.dot(sol) return A, rhs, sol -@pytest.fixture('transpose', [True, False]) -@pytest.fixture('dtype', [np.float64, np.complex128]) +@pytest.mark.parametrize('transpose', [True, False]) +@pytest.mark.parametrize('dtype', [np.float64, np.complex128]) def test_solve(test_mat_data, dtype, transpose): A, rhs, sol = test_mat_data sol = sol.astype(dtype) rhs = rhs.astype(dtype) A = A.astype(dtype) if transpose: - Ainv = Pardiso(A.T).T + with pytest.warns(PardisoTypeConversionWarning): + Ainv = pymatsolver.Pardiso(A.T).T else: - Ainv = Pardiso(A) + Ainv = pymatsolver.Pardiso(A) for i in range(rhs.shape[1]): np.testing.assert_allclose(Ainv * rhs[:, i], sol[:, i], atol=TOL) np.testing.assert_allclose(Ainv * rhs, sol, atol=TOL) def test_symmetric_solve(test_mat_data): A, rhs, sol = test_mat_data - Ainv = Pardiso(A, is_symmetric=True) + Ainv = pymatsolver.Pardiso(A, is_symmetric=True) for i in range(rhs.shape[1]): np.testing.assert_allclose(Ainv * rhs[:, i], sol[:, i], atol=TOL) np.testing.assert_allclose(Ainv * rhs, sol, atol=TOL) @@ -55,7 +50,7 @@ def test_symmetric_solve(test_mat_data): def test_refactor(test_mat_data): A, rhs, sol = test_mat_data - Ainv = Pardiso(A, is_symmetric=True) + Ainv = pymatsolver.Pardiso(A, is_symmetric=True) np.testing.assert_allclose(Ainv * rhs, sol, atol=TOL) # scale rows and columns @@ -71,10 +66,10 @@ def test_n_threads(test_mat_data): max_threads = get_mkl_pardiso_max_threads() print(f'testing setting n_threads to 1 and {max_threads}') - Ainv = Pardiso(A, is_symmetric=True, n_threads=1) + Ainv = pymatsolver.Pardiso(A, is_symmetric=True, n_threads=1) assert Ainv.n_threads == 1 - Ainv2 = Pardiso(A, is_symmetric=True, n_threads=max_threads) + Ainv2 = pymatsolver.Pardiso(A, is_symmetric=True, n_threads=max_threads) assert Ainv2.n_threads == max_threads # the n_threads setting is global so setting Ainv2's n_threads will @@ -100,20 +95,19 @@ def test_n_threads(test_mat_data): # def test(self): # rhs = self.rhs # sol = self.sol -# Ainv = Pardiso(self.A, is_symmetric=True, check_accuracy=True) +# Ainv = pymatsolver.Pardiso(self.A, is_symmetric=True, check_accuracy=True) # with pytest.raises(Exception): # Ainv * rhs # Ainv.clean() # -# Ainv = Pardiso(self.A) +# Ainv = pymatsolver.Pardiso(self.A) # for i in range(3): # assert np.linalg.norm(Ainv * rhs[:, i] - sol[:, i]) < TOL # assert np.linalg.norm(Ainv * rhs - sol, np.inf) < TOL # Ainv.clean() -def test_pardiso_fdem: - +def test_pardiso_fdem(): base_path = os.path.join(os.path.split(os.path.abspath(__file__))[0], 'fdem') data = np.load(os.path.join(base_path, 'A_data.npy')) @@ -123,7 +117,8 @@ def test_pardiso_fdem: A = sp.csr_matrix((data, indices, indptr), shape=(13872, 13872)) rhs = np.load(os.path.join(base_path, 'RHS.npy')) - Ainv = Pardiso(A, check_accuracy=True) + Ainv = pymatsolver.Pardiso(A, check_accuracy=True) + + sol = Ainv * rhs - with pytest.warns(PardisoTypeConversionWarning): - sol = Ainv * rhs.real \ No newline at end of file + np.testing.assert_allclose(A @ sol, rhs, atol=TOL) \ No newline at end of file