From 538668a970f57b9ef83d7890f92dc6623c4f2229 Mon Sep 17 00:00:00 2001 From: asistradition Date: Sun, 23 Oct 2022 11:00:35 -0400 Subject: [PATCH] Fix CSC/QR memory leak --- CHANGELOG.md | 1 + sparse_dot_mkl/_sparse_qr_solver.py | 113 +++++++++++++++---------- sparse_dot_mkl/tests/test_qr_solver.py | 18 ++-- 3 files changed, 80 insertions(+), 52 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 362e60e..531ee53 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,7 @@ ### Version 0.8.3 * Explicit error message when complex data is put into the QR solver +* Fix edge condition memory leak in the QR solver ### Version 0.8.2 diff --git a/sparse_dot_mkl/_sparse_qr_solver.py b/sparse_dot_mkl/_sparse_qr_solver.py index abc02b1..baadc5a 100644 --- a/sparse_dot_mkl/_sparse_qr_solver.py +++ b/sparse_dot_mkl/_sparse_qr_solver.py @@ -14,11 +14,13 @@ import ctypes as _ctypes import scipy.sparse as _spsparse +# Keyed by bool for double-precision SOLVE_FUNCS = { True: MKL._mkl_sparse_d_qr_solve, False: MKL._mkl_sparse_s_qr_solve } +# Keyed by bool for double-precision FACTORIZE_FUNCS = { True: MKL._mkl_sparse_d_qr_factorize, False: MKL._mkl_sparse_s_qr_factorize @@ -31,68 +33,77 @@ def _sparse_qr( """ Solve AX = B for X - :param matrix_a: Sparse matrix A - :type matrix_a: scipy.sparse.csr_matrix - :param matrix_b: Dense matrix B + :param matrix_a: Sparse matrix A as CSR or CSC [M x N] + :type matrix_a: scipy.sparse.spmatrix + :param matrix_b: Dense matrix B [M x 1] :type matrix_b: numpy.ndarray - :return: Dense matrix X + :return: Dense matrix X [N x 1] :rtype: numpy.ndarray """ - mkl_a, dbl, cplx = _create_mkl_sparse(matrix_a) - layout_b, ld_b = _get_numpy_layout(matrix_b) + _mkl_handles = [] - output_shape = matrix_a.shape[1], matrix_b.shape[1] + try: + mkl_a, dbl, _ = _create_mkl_sparse(matrix_a) + _mkl_handles.append(mkl_a) - if _spsparse.isspmatrix_csc(matrix_a): - mkl_a = _convert_to_csr(mkl_a) + layout_b, ld_b = _get_numpy_layout(matrix_b) - # QR Reorder ## - ret_val_r = MKL._mkl_sparse_qr_reorder(mkl_a, matrix_descr()) + output_shape = matrix_a.shape[1], matrix_b.shape[1] - # Check return - _check_return_value(ret_val_r, "mkl_sparse_qr_reorder") + # Convert a CSC matrix to CSR + if _spsparse.isspmatrix_csc(matrix_a): + mkl_a = _convert_to_csr(mkl_a) + _mkl_handles.append(mkl_a) - # QR Factorize ## - factorize_func = FACTORIZE_FUNCS[dbl] + # QR Reorder ## + ret_val_r = MKL._mkl_sparse_qr_reorder(mkl_a, matrix_descr()) - ret_val_f = factorize_func(mkl_a, None) + # Check return + _check_return_value(ret_val_r, "mkl_sparse_qr_reorder") - # Check return - _check_return_value(ret_val_f, factorize_func.__name__) + # QR Factorize ## + factorize_func = FACTORIZE_FUNCS[dbl] - # QR Solve ## - output_dtype = np.float64 if dbl else np.float32 - output_ctype = _ctypes.c_double if dbl else _ctypes.c_float + ret_val_f = factorize_func(mkl_a, None) - output_arr = np.zeros( - output_shape, - dtype=output_dtype, - order="C" if layout_b == LAYOUT_CODE_C else "F" - ) + # Check return + _check_return_value(ret_val_f, factorize_func.__name__) - layout_out, ld_out = _get_numpy_layout(output_arr) + # QR Solve ## + output_dtype = np.float64 if dbl else np.float32 + output_ctype = _ctypes.c_double if dbl else _ctypes.c_float - solve_func = SOLVE_FUNCS[dbl] + output_arr = np.zeros( + output_shape, + dtype=output_dtype, + order="C" if layout_b == LAYOUT_CODE_C else "F" + ) + + layout_out, ld_out = _get_numpy_layout(output_arr) - ret_val_s = solve_func( - 10, - mkl_a, - None, - layout_b, - output_shape[1], - output_arr.ctypes.data_as(_ctypes.POINTER(output_ctype)), - ld_out, - matrix_b, - ld_b - ) + solve_func = SOLVE_FUNCS[dbl] - # Check return - _check_return_value(ret_val_s, solve_func.__name__) + ret_val_s = solve_func( + 10, + mkl_a, + None, + layout_b, + output_shape[1], + output_arr.ctypes.data_as(_ctypes.POINTER(output_ctype)), + ld_out, + matrix_b, + ld_b + ) - _destroy_mkl_handle(mkl_a) + # Check return + _check_return_value(ret_val_s, solve_func.__name__) - return output_arr + return output_arr + + finally: + for _handle in _mkl_handles: + _destroy_mkl_handle(_handle) def sparse_qr_solver( @@ -101,11 +112,19 @@ def sparse_qr_solver( cast=False ): """ + Run the MKL QR solver for Ax=B + and return x - :param matrix_a: - :param matrix_b: - :param cast: - :return: + :param matrix_a: Sparse matrix A as CSR or CSC [M x N] + :type matrix_a: scipy.sparse.spmatrix + :param matrix_b: Dense matrix B [M x 1] + :type matrix_b: numpy.ndarray + :param cast: Convert data to compatible floats and + convert CSC matrix to CSR matrix if necessary + :raise ValueError: Raise a ValueError if the input matrices + cannot be multiplied + :return: Dense matrix X [N x 1] + :rtype: numpy.ndarray """ if _spsparse.isspmatrix_csc(matrix_a) and not cast: diff --git a/sparse_dot_mkl/tests/test_qr_solver.py b/sparse_dot_mkl/tests/test_qr_solver.py index 88874f6..1acdafd 100644 --- a/sparse_dot_mkl/tests/test_qr_solver.py +++ b/sparse_dot_mkl/tests/test_qr_solver.py @@ -40,26 +40,34 @@ def test_sparse_solver_cast_CSC(self): mat3 = sparse_qr_solve_mkl(self.mat1.tocsc(), self.mat2, cast=True) npt.assert_array_almost_equal(self.mat3, mat3) + def test_sparse_solver_cast_CSC_Forder(self): + mat3 = sparse_qr_solve_mkl( + self.mat1.tocsc(), + np.array(self.mat2, order="F"), + cast=True + ) + npt.assert_array_almost_equal(self.mat3, mat3) + def test_sparse_solver_1d_d(self): mat3 = sparse_qr_solve_mkl(self.mat1, self.mat2.ravel()) npt.assert_array_almost_equal(self.mat3.ravel(), mat3) def test_solver_guard_errors(self): with self.assertRaises(ValueError): - mat3 = sparse_qr_solve_mkl(self.mat1, self.mat2.T) + _ = sparse_qr_solve_mkl(self.mat1, self.mat2.T) with self.assertRaises(ValueError): - mat3 = sparse_qr_solve_mkl(self.mat1.tocsc(), self.mat2) + _ = sparse_qr_solve_mkl(self.mat1.tocsc(), self.mat2) with self.assertRaises(ValueError): - mat3 = sparse_qr_solve_mkl(self.mat1.tocoo(), self.mat2, cast=True) + _ = sparse_qr_solve_mkl(self.mat1.tocoo(), self.mat2, cast=True) with self.assertRaises(ValueError) as raised: - mat4 = sparse_qr_solve_mkl(self.mat1.astype(np.cdouble), self.mat2) + _ = sparse_qr_solve_mkl(self.mat1.astype(np.cdouble), self.mat2) self.assertEqual( str(raised.exception), "Complex datatypes are not supported" ) with self.assertRaises(ValueError): - mat5 = sparse_qr_solve_mkl(self.mat1.astype(np.cdouble), self.mat2) + _ = sparse_qr_solve_mkl(self.mat1.astype(np.csingle), self.mat2)