diff --git a/tests/test_Pardiso.py b/tests/test_Pardiso.py index a6f0887..4235cea 100644 --- a/tests/test_Pardiso.py +++ b/tests/test_Pardiso.py @@ -25,23 +25,63 @@ def test_mat_data(): @pytest.mark.parametrize('transpose', [True, False]) @pytest.mark.parametrize('dtype', [np.float64, np.complex128]) -@pytest.mark.parametrize('symmetric', [True, False]) -def test_solve(test_mat_data, dtype, transpose, symmetric): +@pytest.mark.parametrize('symmetry', ["S", "H", None]) +def test_solve(test_mat_data, dtype, transpose, symmetry): + A, sol = test_mat_data - sol = sol.astype(dtype) - A = A.astype(dtype) - if not symmetric: + + if symmetry is None: D = sp.diags(np.linspace(2, 3, A.shape[0])) A = D @ A - rhs = A @ sol + symmetric = False + hermitian = False + elif symmetry == "H": + D = sp.diags(np.linspace(2, 3, A.shape[0])) + if np.issubdtype(dtype, np.complexfloating): + D = D + 1j * sp.diags(np.linspace(3, 4, A.shape[0])) + A = D @ A @ D.T.conjugate() + symmetric = False + hermitian = True + else: + symmetric = True + hermitian = False + + sol = sol.astype(dtype) + A = A.astype(dtype) + if transpose: - Ainv = pymatsolver.Pardiso(A.T, is_symmetric=symmetric).T + rhs = A.T @ sol + Ainv = pymatsolver.Pardiso(A, is_symmetric=symmetric, is_hermitian=hermitian).T else: - Ainv = pymatsolver.Pardiso(A, is_symmetric=symmetric) + rhs = A @ sol + Ainv = pymatsolver.Pardiso(A, is_symmetric=symmetric, is_hermitian=hermitian) for i in range(rhs.shape[1]): npt.assert_allclose(Ainv * rhs[:, i], sol[:, i], atol=TOL) npt.assert_allclose(Ainv * rhs, sol, atol=TOL) +@pytest.mark.parametrize('transpose', [True, False]) +@pytest.mark.parametrize('dtype', [np.float64, np.complex128]) +def test_pardiso_positive_definite(dtype, transpose): + n = 5 + if dtype == np.float64: + L = sp.diags([1, -1], [0, -1], shape=(n, n)) + else: + L = sp.diags([1, -1j], [0, -1], shape=(n, n)) + D = sp.diags(np.linspace(1, 2, n)) + A_pd = L @ D @ (L.T.conjugate()) + + sol = np.linspace(0.9, 1.1, n) + + is_symmetric = dtype == np.float64 + if transpose: + rhs = A_pd.T @ sol + Ainv = pymatsolver.Pardiso(A_pd, is_symmetric=is_symmetric, is_hermitian=True, is_positive_definite=True).T + else: + rhs = A_pd @ sol + Ainv = pymatsolver.Pardiso(A_pd, is_symmetric=is_symmetric, is_hermitian=True, is_positive_definite=True) + + npt.assert_allclose(Ainv @ rhs, sol) + def test_refactor(test_mat_data): A, sol = test_mat_data