Skip to content

Commit

Permalink
ReturnData: remove protected data members (#2503)
Browse files Browse the repository at this point in the history
Removes most protected  data members from `ReturnData`, in particular sundials objects. 

Makes the code slightly easier to follow and we won't have to worry about managing SUNContext outside of runAmiciSimulation with sundials>=6.0.

Also saves some unnecessary copying of various data.

Closes #2348
  • Loading branch information
dweindl authored Sep 26, 2024
1 parent 57d0c29 commit 6312f6c
Show file tree
Hide file tree
Showing 5 changed files with 173 additions and 143 deletions.
7 changes: 4 additions & 3 deletions include/amici/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -1427,19 +1427,20 @@ class Model : public AbstractModel, public ModelDimensions {
* @param x_solver State variables with conservation laws applied
* (solver returns this)
*/
void fx_rdata(AmiVector& x_rdata, AmiVector const& x_solver);
void fx_rdata(gsl::span<realtype> x_rdata, AmiVector const& x_solver);

/**
* @brief Expand conservation law for state sensitivities.
* @param sx_rdata Output buffer for state variables sensitivities with
* conservation laws expanded (stored in `amici::ReturnData`).
* conservation laws expanded
* (stored in `amici::ReturnData` shape `nplist` x `nx`, row-major).
* @param sx_solver State variables sensitivities with conservation laws
* applied (solver returns this)
* @param x_solver State variables with conservation laws
* applied (solver returns this)
*/
void fsx_rdata(
AmiVectorArray& sx_rdata, AmiVectorArray const& sx_solver,
gsl::span<realtype> sx_rdata, AmiVectorArray const& sx_solver,
AmiVector const& x_solver
);

Expand Down
14 changes: 11 additions & 3 deletions include/amici/model_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -332,11 +332,19 @@ struct ModelStateDerived {
struct SimulationState {
/** timepoint */
realtype t;
/** state variables */
/**
* partial state vector, excluding states eliminated from conservation laws
*/
AmiVector x;
/** state variables */
/**
* partial time derivative of state vector, excluding states eliminated
* from conservation laws
*/
AmiVector dx;
/** state variable sensitivity */
/**
* partial sensitivity state vector array, excluding states eliminated from
* conservation laws
*/
AmiVectorArray sx;
/** state of the model that was used for simulation */
ModelState state;
Expand Down
117 changes: 61 additions & 56 deletions include/amici/rdata.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,12 +104,12 @@ class ReturnData : public ModelDimensions {
*/
std::vector<realtype> ts;

/** time derivative (shape `nx`) evaluated at `t_last`. */
/** time derivative (shape `nx_solver`) evaluated at `t_last`. */
std::vector<realtype> xdot;

/**
* Jacobian of differential equation right hand side (shape `nx` x `nx`,
* row-major) evaluated at `t_last`.
* Jacobian of differential equation right hand side (shape `nx_solver` x
* `nx_solver`, row-major) evaluated at `t_last`.
*/
std::vector<realtype> J;

Expand Down Expand Up @@ -156,11 +156,12 @@ class ReturnData : public ModelDimensions {
*/
std::vector<realtype> s2rz;

/** state (shape `nt` x `nx`, row-major) */
/** state (shape `nt` x `nx_rdata`, row-major) */
std::vector<realtype> x;

/**
* parameter derivative of state (shape `nt` x `nplist` x `nx`, row-major)
* parameter derivative of state (shape `nt` x `nplist` x `nx_rdata`,
* row-major)
*/
std::vector<realtype> sx;

Expand Down Expand Up @@ -358,18 +359,18 @@ class ReturnData : public ModelDimensions {
*/
realtype posteq_wrms = NAN;

/** initial state (shape `nx`) */
/** initial state (shape `nx_rdata`) */
std::vector<realtype> x0;

/** preequilibration steady state (shape `nx`) */
/** preequilibration steady state (shape `nx_rdata`) */
std::vector<realtype> x_ss;

/** initial sensitivities (shape `nplist` x `nx`, row-major) */
/** initial sensitivities (shape `nplist` x `nx_rdata`, row-major) */
std::vector<realtype> sx0;

/**
* preequilibration sensitivities
* (shape `nplist` x `nx`, row-major)
* (shape `nplist` x `nx_rdata`, row-major)
*/
std::vector<realtype> sx_ss;

Expand Down Expand Up @@ -463,28 +464,6 @@ class ReturnData : public ModelDimensions {
/** offset for sigma_residuals */
realtype sigma_offset;

/** timepoint for model evaluation*/
realtype t_;

/** partial state vector, excluding states eliminated from conservation laws
*/
AmiVector x_solver_;

/** partial time derivative of state vector, excluding states eliminated
* from conservation laws */
AmiVector dx_solver_;

/** partial sensitivity state vector array, excluding states eliminated from
* conservation laws */
AmiVectorArray sx_solver_;

/** full state vector, including states eliminated from conservation laws */
AmiVector x_rdata_;

/** full sensitivity state vector array, including states eliminated from
* conservation laws */
AmiVectorArray sx_rdata_;

/** array of number of found roots for a certain event type
* (shape `ne`) */
std::vector<int> nroots_;
Expand Down Expand Up @@ -568,40 +547,44 @@ class ReturnData : public ModelDimensions {
template <class T>
void
storeJacobianAndDerivativeInReturnData(T const& problem, Model& model) {
readSimulationState(problem.getFinalSimulationState(), model);
auto simulation_state = problem.getFinalSimulationState();
model.setModelState(simulation_state.state);

AmiVector xdot(nx_solver);
if (!this->xdot.empty() || !this->J.empty())
model.fxdot(t_, x_solver_, dx_solver_, xdot);
model.fxdot(
simulation_state.t, simulation_state.x, simulation_state.dx,
xdot
);

if (!this->xdot.empty())
writeSlice(xdot, this->xdot);

if (!this->J.empty()) {
SUNMatrixWrapper J(nx_solver, nx_solver);
model.fJ(t_, 0.0, x_solver_, dx_solver_, xdot, J);
model.fJ(
simulation_state.t, 0.0, simulation_state.x,
simulation_state.dx, xdot, J
);
// CVODES uses colmajor, so we need to transform to rowmajor
for (int ix = 0; ix < model.nx_solver; ix++)
for (int jx = 0; jx < model.nx_solver; jx++)
this->J.at(ix * model.nx_solver + jx)
= J.data()[ix + model.nx_solver * jx];
}
}
/**
* @brief sets member variables and model state according to provided
* simulation state
* @param state simulation state provided by Problem
* @param model model that was used for forward/backward simulation
*/
void readSimulationState(SimulationState const& state, Model& model);

/**
* @brief Residual function
* @param it time index
* @param model model that was used for forward/backward simulation
* @param simulation_state simulation state the timepoint `it`
* @param edata ExpData instance containing observable data
*/
void fres(int it, Model& model, ExpData const& edata);
void fres(
int it, Model& model, SimulationState const& simulation_state,
ExpData const& edata
);

/**
* @brief Chi-squared function
Expand All @@ -614,17 +597,25 @@ class ReturnData : public ModelDimensions {
* @brief Residual sensitivity function
* @param it time index
* @param model model that was used for forward/backward simulation
* @param simulation_state simulation state the timepoint `it`
* @param edata ExpData instance containing observable data
*/
void fsres(int it, Model& model, ExpData const& edata);
void fsres(
int it, Model& model, SimulationState const& simulation_state,
ExpData const& edata
);

/**
* @brief Fisher information matrix function
* @param it time index
* @param model model that was used for forward/backward simulation
* @param simulation_state simulation state the timepoint `it`
* @param edata ExpData instance containing observable data
*/
void fFIM(int it, Model& model, ExpData const& edata);
void fFIM(
int it, Model& model, SimulationState const& simulation_state,
ExpData const& edata
);

/**
* @brief Set likelihood, state variables, outputs and respective
Expand Down Expand Up @@ -665,46 +656,58 @@ class ReturnData : public ModelDimensions {

/**
* @brief Extracts output information for data-points, expects that
* x_solver_ and sx_solver_ were set appropriately
* the model state was set appropriately
* @param it timepoint index
* @param model model that was used in forward solve
* @param simulation_state simulation state the timepoint `it`
* @param edata ExpData instance carrying experimental data
*/
void getDataOutput(int it, Model& model, ExpData const* edata);
void getDataOutput(
int it, Model& model, SimulationState const& simulation_state,
ExpData const* edata
);

/**
* @brief Extracts data information for forward sensitivity analysis,
* expects that x_solver_ and sx_solver_ were set appropriately
* expects that the model state was set appropriately
* @param it index of current timepoint
* @param model model that was used in forward solve
* @param simulation_state simulation state the timepoint `it`
* @param edata ExpData instance carrying experimental data
*/
void getDataSensisFSA(int it, Model& model, ExpData const* edata);
void getDataSensisFSA(
int it, Model& model, SimulationState const& simulation_state,
ExpData const* edata
);

/**
* @brief Extracts output information for events, expects that x_solver_
* and sx_solver_ were set appropriately
* @brief Extracts output information for events, expects that the model
* state was set appropriately
* @param t event timepoint
* @param rootidx information about which roots fired
* (1 indicating fired, 0/-1 for not)
* @param model model that was used in forward solve
* @param simulation_state simulation state the timepoint `it`
* @param edata ExpData instance carrying experimental data
*/
void getEventOutput(
realtype t, std::vector<int> const rootidx, Model& model,
ExpData const* edata
realtype t, std::vector<int> const& rootidx, Model& model,
SimulationState const& simulation_state, ExpData const* edata
);

/**
* @brief Extracts event information for forward sensitivity analysis,
* expects that x_solver_ and sx_solver_ were set appropriately
* expects the model state was set appropriately
* @param ie index of event type
* @param t event timepoint
* @param model model that was used in forward solve
* @param simulation_state simulation state the timepoint `it`
* @param edata ExpData instance carrying experimental data
*/
void
getEventSensisFSA(int ie, realtype t, Model& model, ExpData const* edata);
void getEventSensisFSA(
int ie, realtype t, Model& model,
SimulationState const& simulation_state, ExpData const* edata
);

/**
* @brief Updates contribution to likelihood from quadratures (xQB),
Expand All @@ -725,12 +728,14 @@ class ReturnData : public ModelDimensions {
* (llhS0), if no preequilibration was run or if forward sensitivities were
* used
* @param model model that was used for forward/backward simulation
* @param simulation_state simulation state the timepoint `it`
* @param llhS0 contribution to likelihood for initial state sensitivities
* @param xB vector with final adjoint state
* (excluding conservation laws)
*/
void handleSx0Forward(
Model const& model, std::vector<realtype>& llhS0, AmiVector& xB
Model const& model, SimulationState const& simulation_state,
std::vector<realtype>& llhS0, AmiVector& xB
) const;
};

Expand Down
8 changes: 4 additions & 4 deletions src/model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1985,28 +1985,28 @@ void Model::fsx0_fixedParameters(AmiVectorArray& sx, AmiVector const& x) {

void Model::fsdx0() {}

void Model::fx_rdata(AmiVector& x_rdata, AmiVector const& x) {
void Model::fx_rdata(gsl::span<realtype> x_rdata, AmiVector const& x) {
fx_rdata(
x_rdata.data(), computeX_pos(x), state_.total_cl.data(),
state_.unscaledParameters.data(), state_.fixedParameters.data()
);
if (always_check_finite_)
checkFinite(
x_rdata.getVector(), ModelQuantity::x_rdata,
x_rdata, ModelQuantity::x_rdata,
std::numeric_limits<realtype>::quiet_NaN()
);
}

void Model::fsx_rdata(
AmiVectorArray& sx_rdata, AmiVectorArray const& sx,
gsl::span<realtype> sx_rdata, AmiVectorArray const& sx,
AmiVector const& x_solver
) {
realtype* stcl = nullptr;
for (int ip = 0; ip < nplist(); ip++) {
if (ncl() > 0)
stcl = &state_.stotal_cl.at(plist(ip) * ncl());
fsx_rdata(
sx_rdata.data(ip), sx.data(ip), stcl,
&sx_rdata[ip * nx_rdata], sx.data(ip), stcl,
state_.unscaledParameters.data(), state_.fixedParameters.data(),
x_solver.data(), state_.total_cl.data(), plist(ip)
);
Expand Down
Loading

0 comments on commit 6312f6c

Please sign in to comment.