From be02a6ec3d7555f2925c866e7c8692076af44be5 Mon Sep 17 00:00:00 2001 From: Daniel Weindl Date: Tue, 1 Oct 2024 13:39:30 +0200 Subject: [PATCH] Refactor SUNLinSolWrapper (#2520) * Make it consistent that SUNLinSolWrapper always holds the associated matrix * Always use SUNMatrixWrapper instead of raw SUNMatrix objects * Implement declared but missing move assignment This makes it a bit easier to finally address #1164. --- include/amici/sundials_linsol_wrapper.h | 67 +++++++++---------------- src/model.cpp | 8 +-- src/steadystateproblem.cpp | 11 ++-- src/sundials_linsol_wrapper.cpp | 67 +++++++++++++++---------- 4 files changed, 74 insertions(+), 79 deletions(-) diff --git a/include/amici/sundials_linsol_wrapper.h b/include/amici/sundials_linsol_wrapper.h index 613eb22156..d8fa1e72b8 100644 --- a/include/amici/sundials_linsol_wrapper.h +++ b/include/amici/sundials_linsol_wrapper.h @@ -33,10 +33,21 @@ class SUNLinSolWrapper { /** * @brief Wrap existing SUNLinearSolver - * @param linsol + * + * @param linsol SUNLinSolWrapper takes ownership of `linsol`. */ explicit SUNLinSolWrapper(SUNLinearSolver linsol); + /** + * @brief Wrap existing SUNLinearSolver + * + * @param linsol SUNLinSolWrapper takes ownership of `linsol`. + * @param A Matrix + */ + explicit SUNLinSolWrapper( + SUNLinearSolver linsol, SUNMatrixWrapper const& A + ); + virtual ~SUNLinSolWrapper(); /** @@ -80,26 +91,17 @@ class SUNLinSolWrapper { /** * @brief Performs any linear solver setup needed, based on an updated * system matrix A. - * @param A */ - void setup(SUNMatrix A) const; - - /** - * @brief Performs any linear solver setup needed, based on an updated - * system matrix A. - * @param A - */ - void setup(SUNMatrixWrapper const& A) const; + void setup() const; /** * @brief Solves a linear system A*x = b - * @param A * @param x A template for cloning vectors needed within the solver. * @param b * @param tol Tolerance (weighted 2-norm), iterative solvers only * @return error flag */ - int Solve(SUNMatrix A, N_Vector x, N_Vector b, realtype tol) const; + int solve(N_Vector x, N_Vector b, realtype tol) const; /** * @brief Returns the last error flag encountered within the linear solver @@ -119,7 +121,7 @@ class SUNLinSolWrapper { * @brief Get the matrix A (matrix solvers only). * @return A */ - virtual SUNMatrix getMatrix() const; + virtual SUNMatrixWrapper& getMatrix(); protected: /** @@ -131,6 +133,9 @@ class SUNLinSolWrapper { /** Wrapped solver */ SUNLinearSolver solver_{nullptr}; + + /** Matrix A for solver. */ + SUNMatrixWrapper A_; }; /** @@ -139,12 +144,12 @@ class SUNLinSolWrapper { class SUNLinSolBand : public SUNLinSolWrapper { public: /** - * @brief Create solver using existing matrix A without taking ownership of - * A. + * @brief Create solver using existing matrix A + * * @param x A template for cloning vectors needed within the solver. * @param A square matrix */ - SUNLinSolBand(N_Vector x, SUNMatrix A); + SUNLinSolBand(N_Vector x, SUNMatrixWrapper A); /** * @brief Create new band solver and matrix A. @@ -153,12 +158,6 @@ class SUNLinSolBand : public SUNLinSolWrapper { * @param lbw lower bandwidth of band matrix A */ SUNLinSolBand(AmiVector const& x, int ubw, int lbw); - - SUNMatrix getMatrix() const override; - - private: - /** Matrix A for solver, only if created by here. */ - SUNMatrixWrapper A_; }; /** @@ -171,12 +170,6 @@ class SUNLinSolDense : public SUNLinSolWrapper { * @param x A template for cloning vectors needed within the solver. */ explicit SUNLinSolDense(AmiVector const& x); - - SUNMatrix getMatrix() const override; - - private: - /** Matrix A for solver, only if created by here. */ - SUNMatrixWrapper A_; }; /** @@ -192,7 +185,7 @@ class SUNLinSolKLU : public SUNLinSolWrapper { * @param x A template for cloning vectors needed within the solver. * @param A sparse matrix */ - SUNLinSolKLU(N_Vector x, SUNMatrix A); + SUNLinSolKLU(N_Vector x, SUNMatrixWrapper A); /** * @brief Create KLU solver and matrix to operate on @@ -202,11 +195,9 @@ class SUNLinSolKLU : public SUNLinSolWrapper { * @param ordering */ SUNLinSolKLU( - AmiVector const& x, int nnz, int sparsetype, StateOrdering ordering + AmiVector const& x, int nnz, int sparsetype, StateOrdering ordering = StateOrdering::COLAMD ); - SUNMatrix getMatrix() const override; - /** * @brief Reinitializes memory and flags for a new factorization * (symbolic and numeric) to be conducted at the next solver setup call. @@ -223,10 +214,6 @@ class SUNLinSolKLU : public SUNLinSolWrapper { * @param ordering */ void setOrdering(StateOrdering ordering); - - private: - /** Sparse matrix A for solver, only if created by here. */ - SUNMatrixWrapper A_; }; #ifdef SUNDIALS_SUPERLUMT @@ -249,7 +236,7 @@ class SUNLinSolSuperLUMT : public SUNLinSolWrapper { * @param A sparse matrix * @param numThreads Number of threads to be used by SuperLUMT */ - SUNLinSolSuperLUMT(N_Vector x, SUNMatrix A, int numThreads); + SUNLinSolSuperLUMT(N_Vector x, SUNMatrixWrapper A, int numThreads); /** * @brief Create SuperLUMT solver and matrix to operate on @@ -279,18 +266,12 @@ class SUNLinSolSuperLUMT : public SUNLinSolWrapper { int numThreads ); - SUNMatrix getMatrix() const override; - /** * @brief Sets the ordering used by SuperLUMT for reducing fill in the * linear solve. * @param ordering */ void setOrdering(StateOrdering ordering); - - private: - /** Sparse matrix A for solver, only if created by here. */ - SUNMatrixWrapper A; }; #endif diff --git a/src/model.cpp b/src/model.cpp index 0c65cefe9d..4658f41008 100644 --- a/src/model.cpp +++ b/src/model.cpp @@ -2221,10 +2221,12 @@ void Model::fdJydy(int const it, AmiVector const& x, ExpData const& edata) { BLASLayout::colMajor, BLASTranspose::noTrans, BLASTranspose::noTrans, nJ, ny, ny, 1.0, &derived_state_.dJydsigma_.at(iyt * nJ * ny), nJ, - derived_state_.dsigmaydy_.data(), ny, 1.0, derived_state_.dJydy_dense_.data(), nJ + derived_state_.dsigmaydy_.data(), ny, 1.0, + derived_state_.dJydy_dense_.data(), nJ ); - auto tmp_sparse = SUNMatrixWrapper(derived_state_.dJydy_dense_, 0.0, CSC_MAT); + auto tmp_sparse + = SUNMatrixWrapper(derived_state_.dJydy_dense_, 0.0, CSC_MAT); auto ret = SUNMatScaleAdd( 1.0, derived_state_.dJydy_.at(iyt), tmp_sparse ); @@ -3079,7 +3081,7 @@ std::vector Model::get_trigger_timepoints() const { return trigger_timepoints; } -void Model::set_steadystate_mask(const std::vector &mask) { +void Model::set_steadystate_mask(std::vector const& mask) { if (mask.size() == 0) { steadystate_mask_.clear(); return; diff --git a/src/steadystateproblem.cpp b/src/steadystateproblem.cpp index 23e31b457e..5feb319d9a 100644 --- a/src/steadystateproblem.cpp +++ b/src/steadystateproblem.cpp @@ -552,8 +552,7 @@ SteadystateProblem::getWrms(Model& model, SensitivityMethod sensi_method) { "steady state computations. Stopping." ); wrms = getWrmsNorm( - xQB_, xQBdot_, steadystate_mask_, atol_quad_, - rtol_quad_, ewtQB_ + xQB_, xQBdot_, steadystate_mask_, atol_quad_, rtol_quad_, ewtQB_ ); } else { /* If we're doing a forward simulation (with or without sensitivities: @@ -563,8 +562,8 @@ SteadystateProblem::getWrms(Model& model, SensitivityMethod sensi_method) { else updateRightHandSide(model); wrms = getWrmsNorm( - state_.x, newton_step_conv_ ? delta_ : xdot_, - steadystate_mask_, atol_, rtol_, ewt_ + state_.x, newton_step_conv_ ? delta_ : xdot_, steadystate_mask_, + atol_, rtol_, ewt_ ); } return wrms; @@ -586,8 +585,8 @@ realtype SteadystateProblem::getWrmsFSA(Model& model) { if (newton_step_conv_) newton_solver_->solveLinearSystem(xdot_); wrms = getWrmsNorm( - state_.sx[ip], xdot_, steadystate_mask_, atol_sensi_, - rtol_sensi_, ewt_ + state_.sx[ip], xdot_, steadystate_mask_, atol_sensi_, rtol_sensi_, + ewt_ ); /* ideally this function would report the maximum of all wrms over all ip, but for practical purposes we can just report the wrms for diff --git a/src/sundials_linsol_wrapper.cpp b/src/sundials_linsol_wrapper.cpp index 6836949cb8..7808620227 100644 --- a/src/sundials_linsol_wrapper.cpp +++ b/src/sundials_linsol_wrapper.cpp @@ -9,6 +9,12 @@ namespace amici { SUNLinSolWrapper::SUNLinSolWrapper(SUNLinearSolver linsol) : solver_(linsol) {} +SUNLinSolWrapper::SUNLinSolWrapper( + SUNLinearSolver linsol, SUNMatrixWrapper const& A +) + : solver_(linsol) + , A_(A) {} + SUNLinSolWrapper::~SUNLinSolWrapper() { if (solver_) SUNLinSolFree(solver_); @@ -16,6 +22,13 @@ SUNLinSolWrapper::~SUNLinSolWrapper() { SUNLinSolWrapper::SUNLinSolWrapper(SUNLinSolWrapper&& other) noexcept { std::swap(solver_, other.solver_); + std::swap(A_, other.A_); +} + +SUNLinSolWrapper& SUNLinSolWrapper::operator=(SUNLinSolWrapper&& other) noexcept { + std::swap(solver_, other.solver_); + std::swap(A_, other.A_); + return *this; } SUNLinearSolver SUNLinSolWrapper::get() const { return solver_; } @@ -31,19 +44,14 @@ int SUNLinSolWrapper::initialize() { return res; } -void SUNLinSolWrapper::setup(SUNMatrix A) const { - auto res = SUNLinSolSetup(solver_, A); +void SUNLinSolWrapper::setup() const { + auto res = SUNLinSolSetup(solver_, A_.get()); if (res != SUNLS_SUCCESS) throw AmiException("Solver setup failed with code %d", res); } -void SUNLinSolWrapper::setup(SUNMatrixWrapper const& A) const { - return setup(A.get()); -} - -int SUNLinSolWrapper::Solve(SUNMatrix A, N_Vector x, N_Vector b, realtype tol) - const { - return SUNLinSolSolve(solver_, A, x, b, tol); +int SUNLinSolWrapper::solve(N_Vector x, N_Vector b, realtype tol) const { + return SUNLinSolSolve(solver_, A_.get(), x, b, tol); } long SUNLinSolWrapper::getLastFlag() const { @@ -54,7 +62,7 @@ int SUNLinSolWrapper::space(long* lenrwLS, long* leniwLS) const { return SUNLinSolSpace(solver_, lenrwLS, leniwLS); } -SUNMatrix SUNLinSolWrapper::getMatrix() const { return nullptr; } +SUNMatrixWrapper& SUNLinSolWrapper::getMatrix() { return A_; } SUNNonLinSolWrapper::SUNNonLinSolWrapper(SUNNonlinearSolver sol) : solver(sol) {} @@ -153,31 +161,29 @@ void SUNNonLinSolWrapper::initialize() { ); } -SUNLinSolBand::SUNLinSolBand(N_Vector x, SUNMatrix A) +SUNLinSolBand::SUNLinSolBand(N_Vector x, SUNMatrixWrapper A) : SUNLinSolWrapper(SUNLinSol_Band(x, A)) { if (!solver_) throw AmiException("Failed to create solver."); } SUNLinSolBand::SUNLinSolBand(AmiVector const& x, int ubw, int lbw) - : A_(SUNMatrixWrapper(x.getLength(), ubw, lbw)) { + : SUNLinSolWrapper(nullptr, SUNMatrixWrapper(x.getLength(), ubw, lbw)) { solver_ = SUNLinSol_Band(const_cast(x.getNVector()), A_); if (!solver_) throw AmiException("Failed to create solver."); } -SUNMatrix SUNLinSolBand::getMatrix() const { return A_.get(); } - SUNLinSolDense::SUNLinSolDense(AmiVector const& x) - : A_(SUNMatrixWrapper(x.getLength(), x.getLength())) { + : SUNLinSolWrapper( + nullptr, SUNMatrixWrapper(x.getLength(), x.getLength()) + ) { solver_ = SUNLinSol_Dense(const_cast(x.getNVector()), A_); if (!solver_) throw AmiException("Failed to create solver."); } -SUNMatrix SUNLinSolDense::getMatrix() const { return A_.get(); } - -SUNLinSolKLU::SUNLinSolKLU(N_Vector x, SUNMatrix A) +SUNLinSolKLU::SUNLinSolKLU(N_Vector x, SUNMatrixWrapper A) : SUNLinSolWrapper(SUNLinSol_KLU(x, A)) { if (!solver_) throw AmiException("Failed to create solver."); @@ -186,7 +192,10 @@ SUNLinSolKLU::SUNLinSolKLU(N_Vector x, SUNMatrix A) SUNLinSolKLU::SUNLinSolKLU( AmiVector const& x, int nnz, int sparsetype, StateOrdering ordering ) - : A_(SUNMatrixWrapper(x.getLength(), x.getLength(), nnz, sparsetype)) { + : SUNLinSolWrapper( + nullptr, + SUNMatrixWrapper(x.getLength(), x.getLength(), nnz, sparsetype) + ) { solver_ = SUNLinSol_KLU(const_cast(x.getNVector()), A_); if (!solver_) throw AmiException("Failed to create solver."); @@ -194,8 +203,6 @@ SUNLinSolKLU::SUNLinSolKLU( setOrdering(ordering); } -SUNMatrix SUNLinSolKLU::getMatrix() const { return A_.get(); } - void SUNLinSolKLU::reInit(int nnz, int reinit_type) { int status = SUNLinSol_KLUReInit(solver_, A_, nnz, reinit_type); if (status != SUNLS_SUCCESS) @@ -413,8 +420,10 @@ int SUNNonLinSolFixedPoint::getSysFn(SUNNonlinSolSysFn* SysFn) const { #ifdef SUNDIALS_SUPERLUMT -SUNLinSolSuperLUMT::SUNLinSolSuperLUMT(N_Vector x, SUNMatrix A, int numThreads) - : SUNLinSolWrapper(SUNLinSol_SuperLUMT(x, A, numThreads)) { +SUNLinSolSuperLUMT::SUNLinSolSuperLUMT( + N_Vector x, SUNMatrixWrapper A, int numThreads +) + : SUNLinSolWrapper(SUNLinSol_SuperLUMT(x, A, numThreads), A) { if (!solver) throw AmiException("Failed to create solver."); } @@ -423,7 +432,10 @@ SUNLinSolSuperLUMT::SUNLinSolSuperLUMT( AmiVector const& x, int nnz, int sparsetype, SUNLinSolSuperLUMT::StateOrdering ordering ) - : A(SUNMatrixWrapper(x.getLength(), x.getLength(), nnz, sparsetype)) { + : SUNLinSolWrapper( + nullptr, + SUNMatrixWrapper(x.getLength(), x.getLength(), nnz, sparsetype) + ) { int numThreads = 1; if (auto env = std::getenv("AMICI_SUPERLUMT_NUM_THREADS")) { numThreads = std::max(1, std::stoi(env)); @@ -440,7 +452,10 @@ SUNLinSolSuperLUMT::SUNLinSolSuperLUMT( AmiVector const& x, int nnz, int sparsetype, StateOrdering ordering, int numThreads ) - : A(SUNMatrixWrapper(x.getLength(), x.getLength(), nnz, sparsetype)) { + : SUNLinSolWrapper( + nullptr, + SUNMatrixWrapper(x.getLength(), x.getLength(), nnz, sparsetype) + ) { solver = SUNLinSol_SuperLUMT(x.getNVector(), A.get(), numThreads); if (!solver) throw AmiException("Failed to create solver."); @@ -448,8 +463,6 @@ SUNLinSolSuperLUMT::SUNLinSolSuperLUMT( setOrdering(ordering); } -SUNMatrix SUNLinSolSuperLUMT::getMatrix() const { return A.get(); } - void SUNLinSolSuperLUMT::setOrdering(StateOrdering ordering) { auto status = SUNLinSol_SuperLUMTSetOrdering(solver, static_cast(ordering));