Skip to content

Commit

Permalink
Refactor: Move spl_ from ModelState to ModelStateDerived (#2505)
Browse files Browse the repository at this point in the history
* Move mutable Model members to ModelStateDerived
* Move ModelState::spl_ to ModelStateDerived

Closes #2504
  • Loading branch information
dweindl authored Sep 26, 2024
1 parent 0b4d405 commit 6b59490
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 48 deletions.
9 changes: 0 additions & 9 deletions include/amici/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -2108,15 +2108,6 @@ class Model : public AbstractModel, public ModelDimensions {
realtype min_sigma_{50.0};

private:
/** Sparse dwdp implicit temporary storage (shape `ndwdp`) */
mutable std::vector<SUNMatrixWrapper> dwdp_hierarchical_;

/** Sparse dwdw temporary storage (shape `ndwdw`) */
mutable SUNMatrixWrapper dwdw_;

/** Sparse dwdx implicit temporary storage (shape `ndwdx`) */
mutable std::vector<SUNMatrixWrapper> dwdx_hierarchical_;

/** Recursion */
int w_recursion_depth_{0};

Expand Down
15 changes: 12 additions & 3 deletions include/amici/model_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,6 @@ struct ModelState {
* (dimension: nplist)
*/
std::vector<int> plist;

/** temporary storage for spline values */
std::vector<realtype> spl_;
};

inline bool operator==(ModelState const& a, ModelState const& b) {
Expand Down Expand Up @@ -323,6 +320,18 @@ struct ModelStateDerived {
/** temporary storage of positified state variables according to
* stateIsNonNegative (dimension: `nx_solver`) */
AmiVector x_pos_tmp_{0};

/** temporary storage for spline values */
std::vector<realtype> spl_;

/** Sparse dwdp implicit temporary storage (shape `ndwdp`) */
std::vector<SUNMatrixWrapper> dwdp_hierarchical_;

/** Sparse dwdw temporary storage (shape `ndwdw`) */
SUNMatrixWrapper dwdw_;

/** Sparse dwdx implicit temporary storage (shape `ndwdx`) */
std::vector<SUNMatrixWrapper> dwdx_hierarchical_;
};

/**
Expand Down
72 changes: 36 additions & 36 deletions src/model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -222,27 +222,27 @@ Model::Model(
*/
if (pythonGenerated) {

dwdw_ = SUNMatrixWrapper(nw, nw, ndwdw, CSC_MAT);
derived_state_.dwdw_ = SUNMatrixWrapper(nw, nw, ndwdw, CSC_MAT);
// size dynamically adapted for dwdx_ and dwdp_
derived_state_.dwdx_ = SUNMatrixWrapper(nw, nx_solver, 0, CSC_MAT);
derived_state_.dwdp_ = SUNMatrixWrapper(nw, np(), 0, CSC_MAT);

for (int irec = 0; irec <= w_recursion_depth_; ++irec) {
/* for the first element we know the exact size, while for all
others we guess the size*/
dwdp_hierarchical_.emplace_back(
derived_state_.dwdp_hierarchical_.emplace_back(
SUNMatrixWrapper(nw, np(), irec * ndwdw + ndwdp, CSC_MAT)
);
dwdx_hierarchical_.emplace_back(
derived_state_.dwdx_hierarchical_.emplace_back(
SUNMatrixWrapper(nw, nx_solver, irec * ndwdw + ndwdx, CSC_MAT)
);
}
assert(
gsl::narrow<int>(dwdp_hierarchical_.size())
gsl::narrow<int>(derived_state_.dwdp_hierarchical_.size())
== w_recursion_depth_ + 1
);
assert(
gsl::narrow<int>(dwdx_hierarchical_.size())
gsl::narrow<int>(derived_state_.dwdx_hierarchical_.size())
== w_recursion_depth_ + 1
);

Expand Down Expand Up @@ -381,7 +381,7 @@ void Model::initializeSplines() {
splines_ = fcreate_splines(
state_.unscaledParameters.data(), state_.fixedParameters.data()
);
state_.spl_.resize(splines_.size(), 0.0);
derived_state_.spl_.resize(splines_.size(), 0.0);
for (auto& spline : splines_) {
spline.compute_coefficients();
}
Expand Down Expand Up @@ -2107,7 +2107,7 @@ void Model::fdydp(realtype const t, AmiVector const& x) {
state_.unscaledParameters.data(), state_.fixedParameters.data(),
state_.h.data(), plist(ip), derived_state_.w_.data(),
state_.total_cl.data(), state_.stotal_cl.data(),
state_.spl_.data(), derived_state_.sspl_.data()
derived_state_.spl_.data(), derived_state_.sspl_.data()
);
} else {
fdydp(
Expand Down Expand Up @@ -2864,7 +2864,7 @@ void Model::fdJrzdsigma(

void Model::fspl(realtype const t) {
for (int ispl = 0; ispl < nspl; ispl++)
state_.spl_[ispl] = splines_[ispl].get_value(t);
derived_state_.spl_[ispl] = splines_[ispl].get_value(t);
}

void Model::fsspl(realtype const t) {
Expand All @@ -2873,7 +2873,7 @@ void Model::fsspl(realtype const t) {
for (int ip = 0; ip < nplist(); ip++) {
for (int ispl = 0; ispl < nspl; ispl++)
sspl_data[ispl + nspl * plist(ip)]
= splines_[ispl].get_sensitivity(t, ip, state_.spl_[ispl]);
= splines_[ispl].get_sensitivity(t, ip, derived_state_.spl_[ispl]);
}
}

Expand All @@ -2884,7 +2884,7 @@ void Model::fw(realtype const t, realtype const* x, bool include_static) {
fspl(t);
fw(derived_state_.w_.data(), t, x, state_.unscaledParameters.data(),
state_.fixedParameters.data(), state_.h.data(), state_.total_cl.data(),
state_.spl_.data(), include_static);
derived_state_.spl_.data(), include_static);

if (always_check_finite_) {
checkFinite(derived_state_.w_, ModelQuantity::w, t);
Expand All @@ -2903,26 +2903,26 @@ void Model::fdwdp(realtype const t, realtype const* x, bool include_static) {
fsspl(t);
fdwdw(t, x, include_static);
if (include_static) {
dwdp_hierarchical_.at(0).zero();
fdwdp_colptrs(dwdp_hierarchical_.at(0));
fdwdp_rowvals(dwdp_hierarchical_.at(0));
derived_state_.dwdp_hierarchical_.at(0).zero();
fdwdp_colptrs(derived_state_.dwdp_hierarchical_.at(0));
fdwdp_rowvals(derived_state_.dwdp_hierarchical_.at(0));
}
fdwdp(
dwdp_hierarchical_.at(0).data(), t, x,
derived_state_.dwdp_hierarchical_.at(0).data(), t, x,
state_.unscaledParameters.data(), state_.fixedParameters.data(),
state_.h.data(), derived_state_.w_.data(), state_.total_cl.data(),
state_.stotal_cl.data(), state_.spl_.data(),
state_.stotal_cl.data(), derived_state_.spl_.data(),
derived_state_.sspl_.data(), include_static
);

for (int irecursion = 1; irecursion <= w_recursion_depth_;
irecursion++) {
dwdw_.sparse_multiply(
dwdp_hierarchical_.at(irecursion),
dwdp_hierarchical_.at(irecursion - 1)
derived_state_.dwdw_.sparse_multiply(
derived_state_.dwdp_hierarchical_.at(irecursion),
derived_state_.dwdp_hierarchical_.at(irecursion - 1)
);
}
derived_state_.dwdp_.sparse_sum(dwdp_hierarchical_);
derived_state_.dwdp_.sparse_sum(derived_state_.dwdp_hierarchical_);

} else {
if (!derived_state_.dwdp_.capacity())
Expand All @@ -2932,7 +2932,7 @@ void Model::fdwdp(realtype const t, realtype const* x, bool include_static) {
derived_state_.dwdp_.data(), t, x, state_.unscaledParameters.data(),
state_.fixedParameters.data(), state_.h.data(),
derived_state_.w_.data(), state_.total_cl.data(),
state_.stotal_cl.data(), state_.spl_.data(),
state_.stotal_cl.data(), derived_state_.spl_.data(),
derived_state_.sspl_.data()
);
}
Expand All @@ -2950,31 +2950,31 @@ void Model::fdwdx(realtype const t, realtype const* x, bool include_static) {

derived_state_.dwdx_.zero();
if (pythonGenerated) {
if (!dwdx_hierarchical_.at(0).capacity())
if (!derived_state_.dwdx_hierarchical_.at(0).capacity())
return;

fdwdw(t, x, include_static);

if (include_static) {
dwdx_hierarchical_.at(0).zero();
fdwdx_colptrs(dwdx_hierarchical_.at(0));
fdwdx_rowvals(dwdx_hierarchical_.at(0));
derived_state_.dwdx_hierarchical_.at(0).zero();
fdwdx_colptrs(derived_state_.dwdx_hierarchical_.at(0));
fdwdx_rowvals(derived_state_.dwdx_hierarchical_.at(0));
}
fdwdx(
dwdx_hierarchical_.at(0).data(), t, x,
derived_state_.dwdx_hierarchical_.at(0).data(), t, x,
state_.unscaledParameters.data(), state_.fixedParameters.data(),
state_.h.data(), derived_state_.w_.data(), state_.total_cl.data(),
state_.spl_.data(), include_static
derived_state_.spl_.data(), include_static
);

for (int irecursion = 1; irecursion <= w_recursion_depth_;
irecursion++) {
dwdw_.sparse_multiply(
dwdx_hierarchical_.at(irecursion),
dwdx_hierarchical_.at(irecursion - 1)
derived_state_.dwdw_.sparse_multiply(
derived_state_.dwdx_hierarchical_.at(irecursion),
derived_state_.dwdx_hierarchical_.at(irecursion - 1)
);
}
derived_state_.dwdx_.sparse_sum(dwdx_hierarchical_);
derived_state_.dwdx_.sparse_sum(derived_state_.dwdx_hierarchical_);

} else {
if (!derived_state_.dwdx_.capacity())
Expand All @@ -2983,7 +2983,7 @@ void Model::fdwdx(realtype const t, realtype const* x, bool include_static) {
fdwdx(
derived_state_.dwdx_.data(), t, x, state_.unscaledParameters.data(),
state_.fixedParameters.data(), state_.h.data(),
derived_state_.w_.data(), state_.total_cl.data(), state_.spl_.data()
derived_state_.w_.data(), state_.total_cl.data(), derived_state_.spl_.data()
);
}

Expand All @@ -2997,19 +2997,19 @@ void Model::fdwdw(realtype const t, realtype const* x, bool include_static) {
return;

if (include_static) {
dwdw_.zero();
fdwdw_colptrs(dwdw_);
fdwdw_rowvals(dwdw_);
derived_state_.dwdw_.zero();
fdwdw_colptrs(derived_state_.dwdw_);
fdwdw_rowvals(derived_state_.dwdw_);
}

fdwdw(
dwdw_.data(), t, x, state_.unscaledParameters.data(),
derived_state_.dwdw_.data(), t, x, state_.unscaledParameters.data(),
state_.fixedParameters.data(), state_.h.data(),
derived_state_.w_.data(), state_.total_cl.data(), include_static
);

if (always_check_finite_) {
checkFinite(dwdw_, ModelQuantity::dwdw, t);
checkFinite(derived_state_.dwdw_, ModelQuantity::dwdw, t);
}
}

Expand Down

0 comments on commit 6b59490

Please sign in to comment.