Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Handle events occuring at fixed timepoints without root-finding
Browse files Browse the repository at this point in the history
A first attempt towards AMICI-dev#2185

For events that occur at known timepoints, we don't need sundials'
root-finding. We can just stop the solver at the respective timepoints
and handle the events.

To be extended to parameterized but state-independent trigger functions.
dweindl committed Dec 4, 2023

Verified

This commit was signed with the committer’s verified signature.
henrikvtcodes Henrik VT
1 parent 81872cc commit baa13af
Showing 26 changed files with 234 additions and 43 deletions.
4 changes: 3 additions & 1 deletion include/amici/forwardproblem.h
Original file line number Diff line number Diff line change
@@ -197,7 +197,9 @@ class ForwardProblem {
SimulationState const& getSimulationStateTimepoint(int it) const {
if (model->getTimepoint(it) == initial_state_.t)
return getInitialSimulationState();
return timepoint_states_.find(model->getTimepoint(it))->second;
auto map_iter = timepoint_states_.find(model->getTimepoint(it));
assert(map_iter != timepoint_states_.end());
return map_iter->second;
};

/**
10 changes: 8 additions & 2 deletions include/amici/model.h
Original file line number Diff line number Diff line change
@@ -12,7 +12,6 @@
#include "amici/vector.h"

#include <map>
#include <memory>
#include <vector>

namespace amici {
@@ -117,14 +116,17 @@ class Model : public AbstractModel, public ModelDimensions {
* @param ndxdotdp_explicit Number of nonzero elements in `dxdotdp_explicit`
* @param ndxdotdx_explicit Number of nonzero elements in `dxdotdx_explicit`
* @param w_recursion_depth Recursion depth of fw
* @param state_independent_events Map of events with state-independent
* triggers functions, mapping trigger timepoints to event indices.
*/
Model(
ModelDimensions const& model_dimensions,
SimulationParameters simulation_parameters,
amici::SecondOrderMode o2mode, std::vector<amici::realtype> idlist,
std::vector<int> z2event, bool pythonGenerated = false,
int ndxdotdp_explicit = 0, int ndxdotdx_explicit = 0,
int w_recursion_depth = 0
int w_recursion_depth = 0,
std::map<realtype, std::vector<int>> state_independent_events = {}
);

/** Destructor. */
@@ -1449,6 +1451,8 @@ class Model : public AbstractModel, public ModelDimensions {
*/
SUNMatrixWrapper const& get_dxdotdp_full() const;

virtual std::vector<double> get_trigger_timepoints() const;

/**
* Flag indicating whether for
* `amici::Solver::sensi_` == `amici::SensitivityOrder::second`
@@ -1462,6 +1466,8 @@ class Model : public AbstractModel, public ModelDimensions {
/** Logger */
Logger* logger = nullptr;

std::map<realtype, std::vector<int>> state_independent_events_ = {};

protected:
/**
* @brief Write part of a slice to a buffer according to indices specified
20 changes: 14 additions & 6 deletions include/amici/model_dimensions.h
Original file line number Diff line number Diff line change
@@ -31,6 +31,7 @@ struct ModelDimensions {
* @param nz Number of event observables
* @param nztrue Number of event observables of the non-augmented model
* @param ne Number of events
* @param ne_solver Number of events that require root-finding
* @param nspl Number of splines
* @param nJ Number of objective functions
* @param nw Number of repeating elements
@@ -58,11 +59,12 @@ struct ModelDimensions {
int const nx_rdata, int const nxtrue_rdata, int const nx_solver,
int const nxtrue_solver, int const nx_solver_reinit, int const np,
int const nk, int const ny, int const nytrue, int const nz,
int const nztrue, int const ne, int const nspl, int const nJ,
int const nw, int const ndwdx, int const ndwdp, int const ndwdw,
int const ndxdotdw, std::vector<int> ndJydy, int const ndxrdatadxsolver,
int const ndxrdatadtcl, int const ndtotal_cldx_rdata, int const nnz,
int const ubw, int const lbw
int const nztrue, int const ne, int const ne_solver, int const nspl,
int const nJ, int const nw, int const ndwdx, int const ndwdp,
int const ndwdw, int const ndxdotdw, std::vector<int> ndJydy,
int const ndxrdatadxsolver, int const ndxrdatadtcl,
int const ndtotal_cldx_rdata, int const nnz, int const ubw,
int const lbw
)
: nx_rdata(nx_rdata)
, nxtrue_rdata(nxtrue_rdata)
@@ -76,6 +78,7 @@ struct ModelDimensions {
, nz(nz)
, nztrue(nztrue)
, ne(ne)
, ne_solver(ne_solver)
, nspl(nspl)
, nw(nw)
, ndwdx(ndwdx)
@@ -104,6 +107,8 @@ struct ModelDimensions {
Expects(nztrue >= 0);
Expects(nztrue <= nz);
Expects(ne >= 0);
Expects(ne_solver >= 0);
Expects(ne >= ne_solver);
Expects(nspl >= 0);
Expects(nw >= 0);
Expects(ndwdx >= 0);
@@ -164,7 +169,10 @@ struct ModelDimensions {
/** Number of events */
int ne{0};

/** numer of spline functions in the model */
/** Number of events that require root-finding */
int ne_solver{0};

/** Number of spline functions in the model */
int nspl{0};

/** Number of common expressions */
8 changes: 6 additions & 2 deletions include/amici/model_ode.h
Original file line number Diff line number Diff line change
@@ -39,19 +39,23 @@ class Model_ODE : public Model {
* @param ndxdotdp_explicit number of nonzero elements dxdotdp_explicit
* @param ndxdotdx_explicit number of nonzero elements dxdotdx_explicit
* @param w_recursion_depth Recursion depth of fw
* @param state_independent_events Map of events with state-independent
* triggers functions, mapping trigger timepoints to event indices.
*/
Model_ODE(
ModelDimensions const& model_dimensions,
SimulationParameters simulation_parameters,
const SecondOrderMode o2mode, std::vector<realtype> const& idlist,
std::vector<int> const& z2event, bool const pythonGenerated = false,
int const ndxdotdp_explicit = 0, int const ndxdotdx_explicit = 0,
int const w_recursion_depth = 0
int const w_recursion_depth = 0,
std::map<realtype, std::vector<int>> state_independent_events
= {}
)
: Model(
model_dimensions, simulation_parameters, o2mode, idlist, z2event,
pythonGenerated, ndxdotdp_explicit, ndxdotdx_explicit,
w_recursion_depth
w_recursion_depth, state_independent_events
) {}

void
1 change: 1 addition & 0 deletions include/amici/serialization.h
Original file line number Diff line number Diff line change
@@ -260,6 +260,7 @@ void serialize(
ar& m.nz;
ar& m.nztrue;
ar& m.ne;
ar& m.ne_solver;
ar& m.nspl;
ar& m.nw;
ar& m.ndwdx;
3 changes: 2 additions & 1 deletion models/model_calvetti/model_calvetti.h
Original file line number Diff line number Diff line change
@@ -46,6 +46,7 @@ class Model_model_calvetti : public amici::Model_DAE {
0,
4,
0,
0,
1,
38,
53,
@@ -207,6 +208,6 @@ class Model_model_calvetti : public amici::Model_DAE {

} // namespace model_model_calvetti

} // namespace amici
} // namespace amici

#endif /* _amici_model_calvetti_h */
3 changes: 2 additions & 1 deletion models/model_dirac/model_dirac.h
Original file line number Diff line number Diff line change
@@ -46,6 +46,7 @@ class Model_model_dirac : public amici::Model_ODE {
0,
2,
0,
0,
1,
0,
0,
@@ -204,6 +205,6 @@ class Model_model_dirac : public amici::Model_ODE {

} // namespace model_model_dirac

} // namespace amici
} // namespace amici

#endif /* _amici_model_dirac_h */
3 changes: 2 additions & 1 deletion models/model_events/model_events.h
Original file line number Diff line number Diff line change
@@ -60,6 +60,7 @@ class Model_model_events : public amici::Model_ODE {
2,
6,
0,
0,
1,
0,
0,
@@ -232,6 +233,6 @@ class Model_model_events : public amici::Model_ODE {

} // namespace model_model_events

} // namespace amici
} // namespace amici

#endif /* _amici_model_events_h */
3 changes: 2 additions & 1 deletion models/model_jakstat_adjoint/model_jakstat_adjoint.h
Original file line number Diff line number Diff line change
@@ -49,6 +49,7 @@ class Model_model_jakstat_adjoint : public amici::Model_ODE {
0,
0,
0,
0,
1,
2,
1,
@@ -210,6 +211,6 @@ class Model_model_jakstat_adjoint : public amici::Model_ODE {

} // namespace model_model_jakstat_adjoint

} // namespace amici
} // namespace amici

#endif /* _amici_model_jakstat_adjoint_h */
3 changes: 2 additions & 1 deletion models/model_jakstat_adjoint_o2/model_jakstat_adjoint_o2.h
Original file line number Diff line number Diff line change
@@ -49,6 +49,7 @@ class Model_model_jakstat_adjoint_o2 : public amici::Model_ODE {
0,
0,
0,
0,
18,
10,
2,
@@ -210,6 +211,6 @@ class Model_model_jakstat_adjoint_o2 : public amici::Model_ODE {

} // namespace model_model_jakstat_adjoint_o2

} // namespace amici
} // namespace amici

#endif /* _amici_model_jakstat_adjoint_o2_h */
3 changes: 2 additions & 1 deletion models/model_nested_events/model_nested_events.h
Original file line number Diff line number Diff line change
@@ -49,6 +49,7 @@ class Model_model_nested_events : public amici::Model_ODE {
0,
4,
0,
0,
1,
0,
0,
@@ -210,6 +211,6 @@ class Model_model_nested_events : public amici::Model_ODE {

} // namespace model_model_nested_events

} // namespace amici
} // namespace amici

#endif /* _amici_model_nested_events_h */
3 changes: 2 additions & 1 deletion models/model_neuron/model_neuron.h
Original file line number Diff line number Diff line change
@@ -63,6 +63,7 @@ class Model_model_neuron : public amici::Model_ODE {
1,
1,
0,
0,
1,
0,
0,
@@ -238,6 +239,6 @@ class Model_model_neuron : public amici::Model_ODE {

} // namespace model_model_neuron

} // namespace amici
} // namespace amici

#endif /* _amici_model_neuron_h */
3 changes: 2 additions & 1 deletion models/model_neuron_o2/model_neuron_o2.h
Original file line number Diff line number Diff line change
@@ -65,6 +65,7 @@ class Model_model_neuron_o2 : public amici::Model_ODE {
1,
1,
0,
0,
5,
2,
2,
@@ -242,6 +243,6 @@ class Model_model_neuron_o2 : public amici::Model_ODE {

} // namespace model_model_neuron_o2

} // namespace amici
} // namespace amici

#endif /* _amici_model_neuron_o2_h */
3 changes: 2 additions & 1 deletion models/model_robertson/model_robertson.h
Original file line number Diff line number Diff line change
@@ -47,6 +47,7 @@ class Model_model_robertson : public amici::Model_DAE {
0,
0,
0,
0,
1,
1,
2,
@@ -209,6 +210,6 @@ class Model_model_robertson : public amici::Model_DAE {

} // namespace model_model_robertson

} // namespace amici
} // namespace amici

#endif /* _amici_model_robertson_h */
3 changes: 2 additions & 1 deletion models/model_steadystate/model_steadystate.h
Original file line number Diff line number Diff line change
@@ -46,6 +46,7 @@ class Model_model_steadystate : public amici::Model_ODE {
0,
0,
0,
0,
1,
2,
2,
@@ -204,6 +205,6 @@ class Model_model_steadystate : public amici::Model_ODE {

} // namespace model_model_steadystate

} // namespace amici
} // namespace amici

#endif /* _amici_model_steadystate_h */
44 changes: 43 additions & 1 deletion python/sdist/amici/de_export.py
Original file line number Diff line number Diff line change
@@ -1425,13 +1425,24 @@ def num_expr(self) -> int:
return len(self.sym("w"))

def num_events(self) -> int:
"""
Total number of Events (those for which root-functions are added and those without).
:return:
number of events
"""
return len(self.sym("h"))

def num_events_solver(self) -> int:
"""
Number of Events.
:return:
number of event symbols (length of the root vector in AMICI)
"""
return len(self.sym("h"))
return sum(
not event.triggers_at_fixed_timepoint() for event in self.events()
)

def sym(self, name: str) -> sp.Matrix:
"""
@@ -1750,6 +1761,16 @@ def parse_events(self) -> None:
# add roots of heaviside functions
self.add_component(root)

# re-order events - first those that require root tracking, then the others
self._events = list(
chain(
itertools.filterfalse(
Event.triggers_at_fixed_timepoint, self._events
),
filter(Event.triggers_at_fixed_timepoint, self._events),
)
)

def get_appearance_counts(self, idxs: List[int]) -> List[int]:
"""
Counts how often a state appears in the time derivative of
@@ -3642,6 +3663,7 @@ def _write_model_header_cpp(self) -> None:
"NZ": self.model.num_eventobs(),
"NZTRUE": self.model.num_eventobs(),
"NEVENT": self.model.num_events(),
"NEVENT_SOLVER": self.model.num_events_solver(),
"NOBJECTIVE": "1",
"NSPL": len(self.model.splines),
"NW": len(self.model.sym("w")),
@@ -3736,6 +3758,7 @@ def _write_model_header_cpp(self) -> None:
)
),
"Z2EVENT": ", ".join(map(str, self.model._z2event)),
"STATE_INDEPENDENT_EVENTS": self._get_state_independent_event_intializer(),
"ID": ", ".join(
(
str(float(isinstance(s, DifferentialState)))
@@ -3871,6 +3894,25 @@ def _get_symbol_id_initializer_list(self, name: str) -> str:
for idx, symbol in enumerate(self.model.sym(name))
)

def _get_state_independent_event_intializer(self) -> str:
tmp_map = {}
for event_idx, event in enumerate(self.model.events()):
if not event.triggers_at_fixed_timepoint():
continue
trigger_time = float(event.get_trigger_time())
try:
tmp_map[trigger_time].append(event_idx)
except KeyError:
tmp_map[trigger_time] = [event_idx]

def vector_initializer(v):
return f"{{{', '.join(map(str, v))}}}"

return ", ".join(
f"{{{trigger_time}, {vector_initializer(event_idxs)}}}"
for trigger_time, event_idxs in tmp_map.items()
)

def _write_c_make_file(self):
"""Write CMake ``CMakeLists.txt`` file for this model."""
sources = "\n".join(
Loading

0 comments on commit baa13af

Please sign in to comment.