diff --git a/CHANGELOG.md b/CHANGELOG.md index eb31d3cd325..168cfb5b19d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -34,6 +34,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 * Replaced `ci` section in `.pre-commit-config.yaml` with a new GitHub workflow with scheduled run to autoupdate the `pre-commit` configuration [#2542](https://github.com/IntelPython/dpnp/pull/2542) * FFT module is updated to perform in-place FFT in intermediate steps of ND FFT [#2543](https://github.com/IntelPython/dpnp/pull/2543) * Reused dpctl tensor include to enable experimental SYCL namespace for complex types [#2546](https://github.com/IntelPython/dpnp/pull/2546) +* Refactored backend implementation of `dpnp.linalg.solve` to use oneMKL LAPACK `gesv` directly [#2558](https://github.com/IntelPython/dpnp/pull/2558) ### Deprecated diff --git a/dpnp/linalg/dpnp_utils_linalg.py b/dpnp/linalg/dpnp_utils_linalg.py index 55d140c5c88..8a542eddc6f 100644 --- a/dpnp/linalg/dpnp_utils_linalg.py +++ b/dpnp/linalg/dpnp_utils_linalg.py @@ -2619,15 +2619,9 @@ def dpnp_solve(a, b): a_usm_arr = dpnp.get_usm_ndarray(a) b_usm_arr = dpnp.get_usm_ndarray(b) - # Due to MKLD-17226 (bug with incorrect checking ldb parameter - # in oneapi::mkl::lapack::gesv_scratchad_size that raises an error - # `invalid argument` when nrhs > n) we can not use _gesv directly. - # This w/a uses _getrf and _getrs instead - # to handle cases where nrhs > n for a.shape = (n x n) - # and b.shape = (n x nrhs). - - # oneMKL LAPACK getrf overwrites `a`. 
- a_h = dpnp.empty_like(a, order="C", dtype=res_type, usm_type=res_usm_type) + # oneMKL LAPACK gesv overwrites `a` and assumes fortran-like array as + # input + a_h = dpnp.empty_like(a, order="F", dtype=res_type, usm_type=res_usm_type) _manager = dpu.SequentialOrderManager[exec_q] dev_evs = _manager.submitted_events @@ -2658,39 +2652,14 @@ def dpnp_solve(a, b): ) _manager.add_event_pair(ht_ev, b_copy_ev) - n = a.shape[0] - - ipiv_h = dpnp.empty_like( - a, - shape=(n,), - dtype=dpnp.int64, - ) - dev_info_h = [0] - - # Call the LAPACK extension function _getrf - # to perform LU decomposition of the input matrix - ht_ev, getrf_ev = li._getrf( - exec_q, - a_h.get_array(), - ipiv_h.get_array(), - dev_info_h, - depends=[a_copy_ev], + # Call the LAPACK extension function _gesv to solve the system of linear + # equations with the coefficient square matrix and + # the dependent variables array. + ht_lapack_ev, gesv_ev = li._gesv( + exec_q, a_h.get_array(), b_h.get_array(), [a_copy_ev, b_copy_ev] ) - _manager.add_event_pair(ht_ev, getrf_ev) - _check_lapack_dev_info(dev_info_h) - - # Call the LAPACK extension function _getrs - # to solve the system of linear equations with an LU-factored - # coefficient square matrix, with multiple right-hand sides. - ht_ev, getrs_ev = li._getrs( - exec_q, - a_h.get_array(), - ipiv_h.get_array(), - b_h.get_array(), - depends=[b_copy_ev, getrf_ev], - ) - _manager.add_event_pair(ht_ev, getrs_ev) + _manager.add_event_pair(ht_lapack_ev, gesv_ev) return b_h