Skip to content

Commit

Permalink
Fix CSC/QR memory leak
Browse files Browse the repository at this point in the history
  • Loading branch information
asistradition committed Oct 23, 2022
1 parent 6babfe0 commit 538668a
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 52 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
### Version 0.8.3

* Explicit error message when complex data is put into the QR solver
* Fix edge condition memory leak in the QR solver

### Version 0.8.2

Expand Down
113 changes: 66 additions & 47 deletions sparse_dot_mkl/_sparse_qr_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@
import ctypes as _ctypes
import scipy.sparse as _spsparse

# Keyed by bool for double-precision
SOLVE_FUNCS = {
True: MKL._mkl_sparse_d_qr_solve,
False: MKL._mkl_sparse_s_qr_solve
}

# Keyed by bool for double-precision
FACTORIZE_FUNCS = {
True: MKL._mkl_sparse_d_qr_factorize,
False: MKL._mkl_sparse_s_qr_factorize
Expand All @@ -31,68 +33,77 @@ def _sparse_qr(
"""
Solve AX = B for X
:param matrix_a: Sparse matrix A
:type matrix_a: scipy.sparse.csr_matrix
:param matrix_b: Dense matrix B
:param matrix_a: Sparse matrix A as CSR or CSC [M x N]
:type matrix_a: scipy.sparse.spmatrix
:param matrix_b: Dense matrix B [M x 1]
:type matrix_b: numpy.ndarray
:return: Dense matrix X
:return: Dense matrix X [N x 1]
:rtype: numpy.ndarray
"""

mkl_a, dbl, cplx = _create_mkl_sparse(matrix_a)
layout_b, ld_b = _get_numpy_layout(matrix_b)
_mkl_handles = []

output_shape = matrix_a.shape[1], matrix_b.shape[1]
try:
mkl_a, dbl, _ = _create_mkl_sparse(matrix_a)
_mkl_handles.append(mkl_a)

if _spsparse.isspmatrix_csc(matrix_a):
mkl_a = _convert_to_csr(mkl_a)
layout_b, ld_b = _get_numpy_layout(matrix_b)

# QR Reorder ##
ret_val_r = MKL._mkl_sparse_qr_reorder(mkl_a, matrix_descr())
output_shape = matrix_a.shape[1], matrix_b.shape[1]

# Check return
_check_return_value(ret_val_r, "mkl_sparse_qr_reorder")
# Convert a CSC matrix to CSR
if _spsparse.isspmatrix_csc(matrix_a):
mkl_a = _convert_to_csr(mkl_a)
_mkl_handles.append(mkl_a)

# QR Factorize ##
factorize_func = FACTORIZE_FUNCS[dbl]
# QR Reorder ##
ret_val_r = MKL._mkl_sparse_qr_reorder(mkl_a, matrix_descr())

ret_val_f = factorize_func(mkl_a, None)
# Check return
_check_return_value(ret_val_r, "mkl_sparse_qr_reorder")

# Check return
_check_return_value(ret_val_f, factorize_func.__name__)
# QR Factorize ##
factorize_func = FACTORIZE_FUNCS[dbl]

# QR Solve ##
output_dtype = np.float64 if dbl else np.float32
output_ctype = _ctypes.c_double if dbl else _ctypes.c_float
ret_val_f = factorize_func(mkl_a, None)

output_arr = np.zeros(
output_shape,
dtype=output_dtype,
order="C" if layout_b == LAYOUT_CODE_C else "F"
)
# Check return
_check_return_value(ret_val_f, factorize_func.__name__)

layout_out, ld_out = _get_numpy_layout(output_arr)
# QR Solve ##
output_dtype = np.float64 if dbl else np.float32
output_ctype = _ctypes.c_double if dbl else _ctypes.c_float

solve_func = SOLVE_FUNCS[dbl]
output_arr = np.zeros(
output_shape,
dtype=output_dtype,
order="C" if layout_b == LAYOUT_CODE_C else "F"
)

layout_out, ld_out = _get_numpy_layout(output_arr)

ret_val_s = solve_func(
10,
mkl_a,
None,
layout_b,
output_shape[1],
output_arr.ctypes.data_as(_ctypes.POINTER(output_ctype)),
ld_out,
matrix_b,
ld_b
)
solve_func = SOLVE_FUNCS[dbl]

# Check return
_check_return_value(ret_val_s, solve_func.__name__)
ret_val_s = solve_func(
10,
mkl_a,
None,
layout_b,
output_shape[1],
output_arr.ctypes.data_as(_ctypes.POINTER(output_ctype)),
ld_out,
matrix_b,
ld_b
)

_destroy_mkl_handle(mkl_a)
# Check return
_check_return_value(ret_val_s, solve_func.__name__)

return output_arr
return output_arr

finally:
for _handle in _mkl_handles:
_destroy_mkl_handle(_handle)


def sparse_qr_solver(
Expand All @@ -101,11 +112,19 @@ def sparse_qr_solver(
cast=False
):
"""
Run the MKL QR solver for Ax=B
and return x
:param matrix_a:
:param matrix_b:
:param cast:
:return:
:param matrix_a: Sparse matrix A as CSR or CSC [M x N]
:type matrix_a: scipy.sparse.spmatrix
:param matrix_b: Dense matrix B [M x 1]
:type matrix_b: numpy.ndarray
:param cast: Convert data to compatible floats and
convert CSC matrix to CSR matrix if necessary
:raise ValueError: Raise a ValueError if the input matrices
cannot be multiplied
:return: Dense matrix X [N x 1]
:rtype: numpy.ndarray
"""

if _spsparse.isspmatrix_csc(matrix_a) and not cast:
Expand Down
18 changes: 13 additions & 5 deletions sparse_dot_mkl/tests/test_qr_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,26 +40,34 @@ def test_sparse_solver_cast_CSC(self):
mat3 = sparse_qr_solve_mkl(self.mat1.tocsc(), self.mat2, cast=True)
npt.assert_array_almost_equal(self.mat3, mat3)

def test_sparse_solver_cast_CSC_Forder(self):
mat3 = sparse_qr_solve_mkl(
self.mat1.tocsc(),
np.array(self.mat2, order="F"),
cast=True
)
npt.assert_array_almost_equal(self.mat3, mat3)

def test_sparse_solver_1d_d(self):
mat3 = sparse_qr_solve_mkl(self.mat1, self.mat2.ravel())
npt.assert_array_almost_equal(self.mat3.ravel(), mat3)

def test_solver_guard_errors(self):
with self.assertRaises(ValueError):
mat3 = sparse_qr_solve_mkl(self.mat1, self.mat2.T)
_ = sparse_qr_solve_mkl(self.mat1, self.mat2.T)

with self.assertRaises(ValueError):
mat3 = sparse_qr_solve_mkl(self.mat1.tocsc(), self.mat2)
_ = sparse_qr_solve_mkl(self.mat1.tocsc(), self.mat2)

with self.assertRaises(ValueError):
mat3 = sparse_qr_solve_mkl(self.mat1.tocoo(), self.mat2, cast=True)
_ = sparse_qr_solve_mkl(self.mat1.tocoo(), self.mat2, cast=True)

with self.assertRaises(ValueError) as raised:
mat4 = sparse_qr_solve_mkl(self.mat1.astype(np.cdouble), self.mat2)
_ = sparse_qr_solve_mkl(self.mat1.astype(np.cdouble), self.mat2)
self.assertEqual(
str(raised.exception),
"Complex datatypes are not supported"
)

with self.assertRaises(ValueError):
mat5 = sparse_qr_solve_mkl(self.mat1.astype(np.cdouble), self.mat2)
_ = sparse_qr_solve_mkl(self.mat1.astype(np.csingle), self.mat2)

0 comments on commit 538668a

Please sign in to comment.