diff --git a/include/amici/model_state.h b/include/amici/model_state.h index 149c5757b8..defb12d4c0 100644 --- a/include/amici/model_state.h +++ b/include/amici/model_state.h @@ -146,71 +146,37 @@ struct ModelStateDerived { , dwdw_(other.dwdw_) , dwdx_hierarchical_(other.dwdx_hierarchical_) , dJydy_dense_(other.dJydy_dense_) { - // Update the SUNContext of all matrices - if (J_.data()) { - J_.get()->sunctx = sunctx_; + // Update the SUNContext of all SUNDIALS objects + J_.set_ctx(sunctx_); + JB_.set_ctx(sunctx_); + dxdotdw_.set_ctx(sunctx_); + dwdx_.set_ctx(sunctx_); + dwdp_.set_ctx(sunctx_); + M_.set_ctx(sunctx_); + MSparse_.set_ctx(sunctx_); + dfdx_.set_ctx(sunctx_); + dxdotdp_full.set_ctx(sunctx_); + dxdotdp_explicit.set_ctx(sunctx_); + dxdotdp_implicit.set_ctx(sunctx_); + dxdotdx_explicit.set_ctx(sunctx_); + dxdotdx_implicit.set_ctx(sunctx_); + dx_rdatadx_solver.set_ctx(sunctx_); + dx_rdatadtcl.set_ctx(sunctx_); + dtotal_cldx_rdata.set_ctx(sunctx_); + dxdotdp.set_ctx(sunctx_); + + for (auto& dJydy : dJydy_) { + dJydy.set_ctx(sunctx_); } - if (JB_.data()) { - JB_.get()->sunctx = sunctx_; + for (auto& dwdp : dwdp_hierarchical_) { + dwdp.set_ctx(sunctx_); } - if (dxdotdw_.data()) { - dxdotdw_.get()->sunctx = sunctx_; - } - if (dwdx_.data()) { - dwdx_.get()->sunctx = sunctx_; - } - if (dwdp_.data()) { - dwdp_.get()->sunctx = sunctx_; - } - if (M_.data()) { - M_.get()->sunctx = sunctx_; - } - if (MSparse_.data()) { - MSparse_.get()->sunctx = sunctx_; - } - if (dfdx_.data()) { - dfdx_.get()->sunctx = sunctx_; - } - if (dxdotdp_full.data()) { - dxdotdp_full.get()->sunctx = sunctx_; - } - if (dxdotdp_explicit.data()) { - dxdotdp_explicit.get()->sunctx = sunctx_; - } - if (dxdotdp_implicit.data()) { - dxdotdp_implicit.get()->sunctx = sunctx_; - } - if (dxdotdx_explicit.data()) { - dxdotdx_explicit.get()->sunctx = sunctx_; - } - if (dxdotdx_implicit.data()) { - dxdotdx_implicit.get()->sunctx = sunctx_; - } - if (dx_rdatadx_solver.data()) { - dx_rdatadx_solver.get()->sunctx = sunctx_; - } - if (dx_rdatadtcl.data()) { - dx_rdatadtcl.get()->sunctx = sunctx_; - } - if (dtotal_cldx_rdata.data()) { - dtotal_cldx_rdata.get()->sunctx = sunctx_; - } - for (auto const& dwdp : dwdp_hierarchical_) { - if (dwdp.data()) { - dwdp.get()->sunctx = sunctx_; - } - } - for (auto const& dwdx : dwdx_hierarchical_) { - if (dwdx.data()) { - dwdx.get()->sunctx = sunctx_; - } - } - if (dwdw_.data()) { - dwdw_.get()->sunctx = sunctx_; - } - if (dJydy_dense_.data()) { - dJydy_dense_.get()->sunctx = sunctx_; + for (auto& dwdx : dwdx_hierarchical_) { + dwdx.set_ctx(sunctx_); } + sspl_.set_ctx(sunctx_); + dwdw_.set_ctx(sunctx_); + dJydy_dense_.set_ctx(sunctx_); } /** diff --git a/include/amici/sundials_matrix_wrapper.h b/include/amici/sundials_matrix_wrapper.h index 942069a803..a7ad361acb 100644 --- a/include/amici/sundials_matrix_wrapper.h +++ b/include/amici/sundials_matrix_wrapper.h @@ -506,6 +506,19 @@ class SUNMatrixWrapper { */ SUNContext get_ctx() const; + /** + * @brief Set SUNContext + * + * Update the SUNContext of the wrapped SUNMatrix. + * + * @param ctx SUNDIALS context to set + */ + void set_ctx(SUNContext ctx) { + if (matrix_) { + matrix_->sunctx = ctx; + } + } + private: /** * @brief SUNMatrix to which all methods are applied diff --git a/include/amici/vector.h b/include/amici/vector.h index 32c436fbda..0a7648e460 100644 --- a/include/amici/vector.h +++ b/include/amici/vector.h @@ -413,6 +413,20 @@ class AmiVectorArray { */ void copy(AmiVectorArray const& other); + /** + * @brief Set SUNContext + * + * If any AmiVector is non-empty, this changes the current SUNContext of the + * associated N_Vector. If empty, do nothing. + * + * @param ctx SUNDIALS context to set + */ + void set_ctx(SUNContext ctx) { + for (auto& vec : vec_array_) { + vec.set_ctx(ctx); + } + } + private: /** main data storage */ std::vector vec_array_;