Skip to content

Commit

Permalink
Refactor ModelDimension and initialization of ModelStateDerived (#2507)
Browse files Browse the repository at this point in the history
* Refactor ModelDimension and initialization of ModelStateDerived

* Move some other model dimensions from the Model constructor to ModelDimensions
* Perform more initialization of ModelStateDerived inside its constructor

* dJydy_dense_ -> ModelStateDerived
  • Loading branch information
dweindl authored Sep 27, 2024
1 parent de1804a commit 823b6bd
Show file tree
Hide file tree
Showing 8 changed files with 111 additions and 107 deletions.
14 changes: 1 addition & 13 deletions include/amici/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,20 +112,14 @@ class Model : public AbstractModel, public ModelDimensions {
* @param o2mode Second order sensitivity mode
* @param idlist Indexes indicating algebraic components (DAE only)
* @param z2event Mapping of event outputs to events
* @param pythonGenerated Flag indicating matlab or python wrapping
* @param ndxdotdp_explicit Number of nonzero elements in `dxdotdp_explicit`
* @param ndxdotdx_explicit Number of nonzero elements in `dxdotdx_explicit`
* @param w_recursion_depth Recursion depth of fw
* @param state_independent_events Map of events with state-independent
* triggers functions, mapping trigger timepoints to event indices.
*/
Model(
ModelDimensions const& model_dimensions,
SimulationParameters simulation_parameters,
amici::SecondOrderMode o2mode, std::vector<amici::realtype> idlist,
std::vector<int> z2event, bool pythonGenerated = false,
int ndxdotdp_explicit = 0, int ndxdotdx_explicit = 0,
int w_recursion_depth = 0,
std::vector<int> z2event,
std::map<realtype, std::vector<int>> state_independent_events = {}
);

Expand Down Expand Up @@ -1458,9 +1452,6 @@ class Model : public AbstractModel, public ModelDimensions {
*/
std::vector<int> const& getReinitializationStateIdxs() const;

/** Flag indicating Matlab- or Python-based model generation */
bool pythonGenerated = false;

/**
* @brief getter for dxdotdp (matlab generated)
* @return dxdotdp
Expand Down Expand Up @@ -2098,9 +2089,6 @@ class Model : public AbstractModel, public ModelDimensions {
realtype min_sigma_{50.0};

private:
/** Recursion */
int w_recursion_depth_{0};

/** Simulation parameters, initial state, etc. */
SimulationParameters simulation_parameters_;

Expand Down
11 changes: 2 additions & 9 deletions include/amici/model_dae.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,26 +36,19 @@ class Model_DAE : public Model {
* @param o2mode second order sensitivity mode
* @param idlist indexes indicating algebraic components (DAE only)
* @param z2event mapping of event outputs to events
* @param pythonGenerated flag indicating matlab or python wrapping
* @param ndxdotdp_explicit number of nonzero elements dxdotdp_explicit
* @param ndxdotdx_explicit number of nonzero elements dxdotdx_explicit
* @param w_recursion_depth Recursion depth of fw
* @param state_independent_events Map of events with state-independent
* triggers functions, mapping trigger timepoints to event indices.
*/
Model_DAE(
ModelDimensions const& model_dimensions,
SimulationParameters simulation_parameters,
SecondOrderMode const o2mode, std::vector<realtype> const& idlist,
std::vector<int> const& z2event, bool const pythonGenerated = false,
int const ndxdotdp_explicit = 0, int const ndxdotdx_explicit = 0,
int const w_recursion_depth = 0,
std::vector<int> const& z2event,
std::map<realtype, std::vector<int>> state_independent_events = {}
)
: Model(
model_dimensions, simulation_parameters, o2mode, idlist, z2event,
pythonGenerated, ndxdotdp_explicit, ndxdotdx_explicit,
w_recursion_depth, state_independent_events
state_independent_events
) {
derived_state_.M_ = SUNMatrixWrapper(nx_solver, nx_solver);
auto M_nnz = static_cast<sunindextype>(
Expand Down
31 changes: 28 additions & 3 deletions include/amici/model_dimensions.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ namespace amici {
/**
* @brief Container for model dimensions.
*
* Holds number of states, observables, etc.
* Holds number of state variables, observables, etc.
*/
struct ModelDimensions {
/** Default ctor */
Expand Down Expand Up @@ -54,6 +54,11 @@ struct ModelDimensions {
* @param nnz Number of nonzero elements in Jacobian
* @param ubw Upper matrix bandwidth in the Jacobian
* @param lbw Lower matrix bandwidth in the Jacobian
* @param pythonGenerated Flag indicating model creation from Matlab or
* Python
* @param ndxdotdp_explicit Number of nonzero elements in `dxdotdp_explicit`
* @param ndxdotdx_explicit Number of nonzero elements in `dxdotdx_explicit`
* @param w_recursion_depth Recursion depth of fw
*/
ModelDimensions(
int const nx_rdata, int const nxtrue_rdata, int const nx_solver,
Expand All @@ -64,7 +69,8 @@ struct ModelDimensions {
int const ndwdw, int const ndxdotdw, std::vector<int> ndJydy,
int const ndxrdatadxsolver, int const ndxrdatadtcl,
int const ndtotal_cldx_rdata, int const nnz, int const ubw,
int const lbw
int const lbw, bool pythonGenerated = false, int ndxdotdp_explicit = 0,
int ndxdotdx_explicit = 0, int w_recursion_depth = 0
)
: nx_rdata(nx_rdata)
, nxtrue_rdata(nxtrue_rdata)
Expand Down Expand Up @@ -92,7 +98,11 @@ struct ModelDimensions {
, nnz(nnz)
, nJ(nJ)
, ubw(ubw)
, lbw(lbw) {
, lbw(lbw)
, pythonGenerated(pythonGenerated)
, ndxdotdp_explicit(ndxdotdp_explicit)
, ndxdotdx_explicit(ndxdotdx_explicit)
, w_recursion_depth(w_recursion_depth) {
Expects(nxtrue_rdata >= 0);
Expects(nxtrue_rdata <= nx_rdata);
Expects(nxtrue_solver >= 0);
Expand Down Expand Up @@ -128,6 +138,9 @@ struct ModelDimensions {
Expects(nJ >= 0);
Expects(ubw >= 0);
Expects(lbw >= 0);
Expects(ndxdotdp_explicit >= 0);
Expects(ndxdotdx_explicit >= 0);
Expects(w_recursion_depth >= 0);
}

/** Number of states */
Expand Down Expand Up @@ -229,6 +242,18 @@ struct ModelDimensions {

/** Lower bandwidth of the Jacobian */
int lbw{0};

/** Flag indicating model creation from Matlab or Python */
bool pythonGenerated = false;

/** Number of nonzero elements in `dxdotdx_explicit` */
int ndxdotdp_explicit = 0;

/** Number of nonzero elements in `dxdotdp_explicit` */
int ndxdotdx_explicit = 0;

/** Recursion depth of fw */
int w_recursion_depth = 0;
};

} // namespace amici
Expand Down
11 changes: 2 additions & 9 deletions include/amici/model_ode.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,26 +35,19 @@ class Model_ODE : public Model {
* @param o2mode second order sensitivity mode
* @param idlist indexes indicating algebraic components (DAE only)
* @param z2event mapping of event outputs to events
* @param pythonGenerated flag indicating matlab or python wrapping
* @param ndxdotdp_explicit number of nonzero elements dxdotdp_explicit
* @param ndxdotdx_explicit number of nonzero elements dxdotdx_explicit
* @param w_recursion_depth Recursion depth of fw
* @param state_independent_events Map of events with state-independent
* triggers functions, mapping trigger timepoints to event indices.
*/
Model_ODE(
ModelDimensions const& model_dimensions,
SimulationParameters simulation_parameters,
SecondOrderMode const o2mode, std::vector<realtype> const& idlist,
std::vector<int> const& z2event, bool const pythonGenerated = false,
int const ndxdotdp_explicit = 0, int const ndxdotdx_explicit = 0,
int const w_recursion_depth = 0,
std::vector<int> const& z2event,
std::map<realtype, std::vector<int>> state_independent_events = {}
)
: Model(
model_dimensions, simulation_parameters, o2mode, idlist, z2event,
pythonGenerated, ndxdotdp_explicit, ndxdotdx_explicit,
w_recursion_depth, state_independent_events
state_independent_events
) {}

void
Expand Down
3 changes: 3 additions & 0 deletions include/amici/model_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,9 @@ struct ModelStateDerived {

/** Sparse dwdx implicit temporary storage (shape `ndwdx`) */
std::vector<SUNMatrixWrapper> dwdx_hierarchical_;

/** Temporary storage for dense dJydy (dimension: `nJ` x `ny`) */
SUNMatrixWrapper dJydy_dense_;
};

/**
Expand Down
77 changes: 10 additions & 67 deletions src/model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -177,19 +177,15 @@ Model::Model(
ModelDimensions const& model_dimensions,
SimulationParameters simulation_parameters, SecondOrderMode o2mode,
std::vector<realtype> idlist, std::vector<int> z2event,
bool const pythonGenerated, int const ndxdotdp_explicit,
int const ndxdotdx_explicit, int const w_recursion_depth,
std::map<realtype, std::vector<int>> state_independent_events
)
: ModelDimensions(model_dimensions)
, pythonGenerated(pythonGenerated)
, o2mode(o2mode)
, idlist(std::move(idlist))
, state_independent_events_(std::move(state_independent_events))
, derived_state_(model_dimensions)
, z2event_(std::move(z2event))
, state_is_non_negative_(nx_solver, false)
, w_recursion_depth_(w_recursion_depth)
, simulation_parameters_(std::move(simulation_parameters)) {
Expects(
model_dimensions.np
Expand Down Expand Up @@ -217,60 +213,6 @@ Model::Model(

root_initial_values_.resize(ne, true);

/* If Matlab wrapped: dxdotdp is a full AmiVector,
if Python wrapped: dxdotdp_explicit and dxdotdp_implicit are CSC matrices
*/
if (pythonGenerated) {

derived_state_.dwdw_ = SUNMatrixWrapper(nw, nw, ndwdw, CSC_MAT);
// size dynamically adapted for dwdx_ and dwdp_
derived_state_.dwdx_ = SUNMatrixWrapper(nw, nx_solver, 0, CSC_MAT);
derived_state_.dwdp_ = SUNMatrixWrapper(nw, np(), 0, CSC_MAT);

for (int irec = 0; irec <= w_recursion_depth_; ++irec) {
/* for the first element we know the exact size, while for all
others we guess the size*/
derived_state_.dwdp_hierarchical_.emplace_back(
SUNMatrixWrapper(nw, np(), irec * ndwdw + ndwdp, CSC_MAT)
);
derived_state_.dwdx_hierarchical_.emplace_back(
SUNMatrixWrapper(nw, nx_solver, irec * ndwdw + ndwdx, CSC_MAT)
);
}
assert(
gsl::narrow<int>(derived_state_.dwdp_hierarchical_.size())
== w_recursion_depth_ + 1
);
assert(
gsl::narrow<int>(derived_state_.dwdx_hierarchical_.size())
== w_recursion_depth_ + 1
);

derived_state_.dxdotdp_explicit
= SUNMatrixWrapper(nx_solver, np(), ndxdotdp_explicit, CSC_MAT);
// guess size, will be dynamically reallocated
derived_state_.dxdotdp_implicit
= SUNMatrixWrapper(nx_solver, np(), ndwdp + ndxdotdw, CSC_MAT);
derived_state_.dxdotdx_explicit = SUNMatrixWrapper(
nx_solver, nx_solver, ndxdotdx_explicit, CSC_MAT
);
// guess size, will be dynamically reallocated
derived_state_.dxdotdx_implicit
= SUNMatrixWrapper(nx_solver, nx_solver, ndwdx + ndxdotdw, CSC_MAT);
// dynamically allocate on first call
derived_state_.dxdotdp_full
= SUNMatrixWrapper(nx_solver, np(), 0, CSC_MAT);

for (int iytrue = 0; iytrue < nytrue; ++iytrue)
derived_state_.dJydy_.emplace_back(
SUNMatrixWrapper(nJ, ny, ndJydy.at(iytrue), CSC_MAT)
);
} else {
derived_state_.dwdx_ = SUNMatrixWrapper(nw, nx_solver, ndwdx, CSC_MAT);
derived_state_.dwdp_ = SUNMatrixWrapper(nw, np(), ndwdp, CSC_MAT);
derived_state_.dJydy_matlab_
= std::vector<realtype>(nJ * nytrue * ny, 0.0);
}
requireSensitivitiesForAllParameters();
}

Expand Down Expand Up @@ -2250,7 +2192,6 @@ void Model::fdJydy(int const it, AmiVector const& x, ExpData const& edata) {
if (pythonGenerated) {
fdJydsigma(it, x, edata);
fdsigmaydy(it, &edata);
SUNMatrixWrapper tmp_dense(nJ, ny);

setNaNtoZero(derived_state_.dJydsigma_);
setNaNtoZero(derived_state_.dsigmaydy_);
Expand All @@ -2275,15 +2216,15 @@ void Model::fdJydy(int const it, AmiVector const& x, ExpData const& edata) {
// dJydy += dJydsigma * dsigmaydy
// C(nJ,ny) A(nJ,ny) * B(ny,ny)
// sparse dense dense
tmp_dense.zero();
derived_state_.dJydy_dense_.zero();
amici_dgemm(
BLASLayout::colMajor, BLASTranspose::noTrans,
BLASTranspose::noTrans, nJ, ny, ny, 1.0,
&derived_state_.dJydsigma_.at(iyt * nJ * ny), nJ,
derived_state_.dsigmaydy_.data(), ny, 1.0, tmp_dense.data(), nJ
derived_state_.dsigmaydy_.data(), ny, 1.0, derived_state_.dJydy_dense_.data(), nJ
);

auto tmp_sparse = SUNMatrixWrapper(tmp_dense, 0.0, CSC_MAT);
auto tmp_sparse = SUNMatrixWrapper(derived_state_.dJydy_dense_, 0.0, CSC_MAT);
auto ret = SUNMatScaleAdd(
1.0, derived_state_.dJydy_.at(iyt), tmp_sparse
);
Expand Down Expand Up @@ -2871,8 +2812,9 @@ void Model::fsspl(realtype const t) {
realtype* sspl_data = derived_state_.sspl_.data();
for (int ip = 0; ip < nplist(); ip++) {
for (int ispl = 0; ispl < nspl; ispl++)
sspl_data[ispl + nspl * plist(ip)]
= splines_[ispl].get_sensitivity(t, ip, derived_state_.spl_[ispl]);
sspl_data[ispl + nspl * plist(ip)] = splines_[ispl].get_sensitivity(
t, ip, derived_state_.spl_[ispl]
);
}
}

Expand Down Expand Up @@ -2914,7 +2856,7 @@ void Model::fdwdp(realtype const t, realtype const* x, bool include_static) {
derived_state_.sspl_.data(), include_static
);

for (int irecursion = 1; irecursion <= w_recursion_depth_;
for (int irecursion = 1; irecursion <= w_recursion_depth;
irecursion++) {
derived_state_.dwdw_.sparse_multiply(
derived_state_.dwdp_hierarchical_.at(irecursion),
Expand Down Expand Up @@ -2966,7 +2908,7 @@ void Model::fdwdx(realtype const t, realtype const* x, bool include_static) {
derived_state_.spl_.data(), include_static
);

for (int irecursion = 1; irecursion <= w_recursion_depth_;
for (int irecursion = 1; irecursion <= w_recursion_depth;
irecursion++) {
derived_state_.dwdw_.sparse_multiply(
derived_state_.dwdx_hierarchical_.at(irecursion),
Expand All @@ -2982,7 +2924,8 @@ void Model::fdwdx(realtype const t, realtype const* x, bool include_static) {
fdwdx(
derived_state_.dwdx_.data(), t, x, state_.unscaledParameters.data(),
state_.fixedParameters.data(), state_.h.data(),
derived_state_.w_.data(), state_.total_cl.data(), derived_state_.spl_.data()
derived_state_.w_.data(), state_.total_cl.data(),
derived_state_.spl_.data()
);
}

Expand Down
10 changes: 5 additions & 5 deletions src/model_header.template.h
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,11 @@ class Model_TPL_MODELNAME : public amici::Model_TPL_MODEL_TYPE_UPPER {
TPL_NDTOTALCLDXRDATA, // ndtotal_cldx_rdata
0, // nnz
TPL_UBW, // ubw
TPL_LBW // lbw
TPL_LBW, // lbw
true, // pythonGenerated
TPL_NDXDOTDP_EXPLICIT, // ndxdotdp_explicit
TPL_NDXDOTDX_EXPLICIT, // ndxdotdx_explicit
TPL_W_RECURSION_DEPTH // w_recursion_depth
),
amici::SimulationParameters(
std::vector<realtype>{TPL_FIXED_PARAMETERS}, // fixedParameters
Expand All @@ -144,10 +148,6 @@ class Model_TPL_MODELNAME : public amici::Model_TPL_MODEL_TYPE_UPPER {
TPL_O2MODE, // o2mode
std::vector<realtype>{TPL_ID}, // idlist
std::vector<int>{TPL_Z2EVENT}, // z2events
true, // pythonGenerated
TPL_NDXDOTDP_EXPLICIT, // ndxdotdp_explicit
TPL_NDXDOTDX_EXPLICIT, // ndxdotdx_explicit
TPL_W_RECURSION_DEPTH, // w_recursion_depth
{TPL_STATE_INDEPENDENT_EVENTS} // state-independent events
) {
root_initial_values_ = std::vector<bool>(
Expand Down
Loading

0 comments on commit 823b6bd

Please sign in to comment.