Skip to content

Commit

Permalink
Merge branch 'develop' into jax_sciml
Browse files Browse the repository at this point in the history
  • Loading branch information
FFroehlich authored Dec 19, 2024
2 parents 731b925 + 1716181 commit 2ec62a5
Show file tree
Hide file tree
Showing 12 changed files with 200 additions and 44 deletions.
1 change: 1 addition & 0 deletions include/amici/defines.h
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,7 @@ enum class RDataReporting {
full,
residuals,
likelihood,
observables_likelihood,
};

/** boundary conditions for splines */
Expand Down
7 changes: 7 additions & 0 deletions include/amici/rdata.h
Original file line number Diff line number Diff line change
Expand Up @@ -476,6 +476,13 @@ class ReturnData : public ModelDimensions {
*/
void initializeLikelihoodReporting(bool quadratic_llh);

/**
* @brief initializes storage for observables + likelihood reporting mode
* @param quadratic_llh whether model defines a quadratic nllh and computing
* res, sres and FIM makes sense.
*/
void initializeObservablesLikelihoodReporting(bool quadratic_llh);

/**
* @brief initializes storage for residual reporting mode
* @param enable_res whether residuals are to be computed
Expand Down
2 changes: 2 additions & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,7 @@ filterwarnings =
ignore:.*PyDevIPCompleter6.*:DeprecationWarning
# ignore numpy log(0) warnings (np.log(0) = -inf)
ignore:divide by zero encountered in log:RuntimeWarning
# ignore jax deprecation warnings
ignore:jax.* is deprecated:DeprecationWarning

norecursedirs = .git amici_models build doc documentation matlab models ThirdParty amici sdist examples
4 changes: 2 additions & 2 deletions python/examples/example_jax_petab/ExampleJaxPEtab.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@
"simulation_condition = (\"model1_data1\",)\n",
"\n",
"# Load condition-specific data\n",
"ts_init, ts_dyn, ts_posteq, my, iys, iy_trafos = jax_problem._measurements[\n",
"ts_dyn, ts_posteq, my, iys, iy_trafos = jax_problem._measurements[\n",
" simulation_condition\n",
"]\n",
"\n",
Expand All @@ -378,14 +378,14 @@
"def grad_ts_dyn(tt):\n",
" return jax_problem.model.simulate_condition(\n",
" p=p,\n",
" ts_init=ts_init,\n",
" ts_dyn=tt,\n",
" ts_posteq=ts_posteq,\n",
" my=jnp.array(my),\n",
" iys=jnp.array(iys),\n",
" iy_trafos=jnp.array(iy_trafos),\n",
" solver=diffrax.Kvaerno5(),\n",
" controller=diffrax.PIDController(atol=1e-8, rtol=1e-8),\n",
" steady_state_event=diffrax.steady_state_event(),\n",
" max_steps=2**10,\n",
" adjoint=diffrax.DirectAdjoint(),\n",
" ret=ReturnValue.y, # Return observables\n",
Expand Down
67 changes: 52 additions & 15 deletions python/sdist/amici/jax/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
import jax
import jaxtyping as jt

from collections.abc import Callable


class ReturnValue(enum.Enum):
llh = "log-likelihood"
Expand All @@ -32,6 +34,13 @@ class JAXModel(eqx.Module):
JAXModel provides an abstract base class for a JAX-based implementation of an AMICI model. The class implements
routines for simulation and evaluation of derived quantities, model specific implementations need to be provided by
classes inheriting from JAXModel.
:ivar api_version:
API version of the derived class. Needs to match the API version of the base class (MODEL_API_VERSION).
:ivar MODEL_API_VERSION:
API version of the base class.
:ivar jax_py_file:
Path to the JAX model file.
"""

MODEL_API_VERSION = "0.0.2"
Expand Down Expand Up @@ -249,6 +258,9 @@ def _eq(
x0: jt.Float[jt.Array, "nxs"],
solver: diffrax.AbstractSolver,
controller: diffrax.AbstractStepSizeController,
steady_state_event: Callable[
..., diffrax._custom_types.BoolScalarLike
],
max_steps: jnp.int_,
) -> tuple[jt.Float[jt.Array, "1 nxs"], dict]:
"""
Expand Down Expand Up @@ -279,10 +291,20 @@ def _eq(
stepsize_controller=controller,
max_steps=max_steps,
adjoint=diffrax.DirectAdjoint(),
event=diffrax.Event(cond_fn=diffrax.steady_state_event()),
event=diffrax.Event(
cond_fn=steady_state_event,
),
throw=False,
)
return sol.ys[-1, :], sol.stats
# If the event was triggered, the event mask is True and the solution is the steady state. Otherwise, the
# solution is the last state and the event mask is False. In the latter case, we return inf for the steady
# state.
ys = jnp.where(
sol.event_mask,
sol.ys[-1, :],
jnp.inf * jnp.ones_like(sol.ys[-1, :]),
)
return ys, sol.stats

def _solve(
self,
Expand Down Expand Up @@ -443,7 +465,6 @@ def _sigmays(
def simulate_condition(
self,
p: jt.Float[jt.Array, "np"],
ts_init: jt.Float[jt.Array, "nt_preeq"],
ts_dyn: jt.Float[jt.Array, "nt_dyn"],
ts_posteq: jt.Float[jt.Array, "nt_posteq"],
my: jt.Float[jt.Array, "nt"],
Expand All @@ -452,6 +473,9 @@ def simulate_condition(
solver: diffrax.AbstractSolver,
controller: diffrax.AbstractStepSizeController,
adjoint: diffrax.AbstractAdjoint,
steady_state_event: Callable[
..., diffrax._custom_types.BoolScalarLike
],
max_steps: int | jnp.int_,
x_preeq: jt.Float[jt.Array, "*nx"] = jnp.array([]),
mask_reinit: jt.Bool[jt.Array, "*nx"] = jnp.array([]),
Expand All @@ -463,13 +487,9 @@ def simulate_condition(
:param p:
parameters for simulation ordered according to ids in :ivar parameter_ids:
:param ts_init:
time points that do not require simulation. Usually valued 0.0, but needs to be shaped according to
the number of observables that are evaluated before dynamic simulation.
:param ts_dyn:
time points for dynamic simulation. Usually valued > 0.0 and sorted in monotonically increasing order.
Duplicate time points are allowed to facilitate the evaluation of multiple observables at specific time
points.
time points for dynamic simulation. Sorted in monotonically increasing order but duplicate time points are
allowed to facilitate the evaluation of multiple observables at specific time points.
:param ts_posteq:
time points for post-equilibration. Usually valued \Infty, but needs to be shaped according to
the number of observables that are evaluated after post-equilibration.
Expand Down Expand Up @@ -509,8 +529,6 @@ def simulate_condition(
x_solver = self._x_solver(x)
tcl = self._tcl(x, p)

x_preq = jnp.repeat(x_solver.reshape(1, -1), ts_init.shape[0], axis=0)

# Dynamic simulation
if ts_dyn.shape[0]:
x_dyn, stats_dyn = self._solve(
Expand All @@ -533,7 +551,13 @@ def simulate_condition(
# Post-equilibration
if ts_posteq.shape[0]:
x_solver, stats_posteq = self._eq(
p, tcl, x_solver, solver, controller, max_steps
p,
tcl,
x_solver,
solver,
controller,
steady_state_event,
max_steps,
)
else:
stats_posteq = None
Expand All @@ -542,8 +566,8 @@ def simulate_condition(
x_solver.reshape(1, -1), ts_posteq.shape[0], axis=0
)

ts = jnp.concatenate((ts_init, ts_dyn, ts_posteq), axis=0)
x = jnp.concatenate((x_preq, x_dyn, x_posteq), axis=0)
ts = jnp.concatenate((ts_dyn, ts_posteq), axis=0)
x = jnp.concatenate((x_dyn, x_posteq), axis=0)

nllhs = self._nllhs(ts, x, p, tcl, my, iys)
llh = -jnp.sum(nllhs)
Expand Down Expand Up @@ -604,13 +628,20 @@ def preequilibrate_condition(
mask_reinit: jt.Bool[jt.Array, "*nx"],
solver: diffrax.AbstractSolver,
controller: diffrax.AbstractStepSizeController,
steady_state_event: Callable[
..., diffrax._custom_types.BoolScalarLike
],
max_steps: int | jnp.int_,
) -> tuple[jt.Float[jt.Array, "nx"], dict]:
r"""
Simulate a condition.
:param p:
parameters for simulation ordered according to ids in :ivar parameter_ids:
:param x_reinit:
re-initialized state vector. If not provided, the state vector is not re-initialized.
:param mask_reinit:
mask for re-initialization. If `True`, the corresponding state variable is re-initialized.
:param solver:
ODE solver
:param controller:
Expand All @@ -627,7 +658,13 @@ def preequilibrate_condition(
tcl = self._tcl(x0, p)
current_x = self._x_solver(x0)
current_x, stats_preeq = self._eq(
p, tcl, current_x, solver, controller, max_steps
p,
tcl,
current_x,
solver,
controller,
steady_state_event,
max_steps,
)

return self._x_rdata(current_x, tcl), dict(stats_preeq=stats_preeq)
Expand Down
44 changes: 30 additions & 14 deletions python/sdist/amici/jax/petab.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from numbers import Number
from collections.abc import Iterable
from pathlib import Path
from collections.abc import Callable


import diffrax
Expand Down Expand Up @@ -71,7 +72,7 @@ class JAXProblem(eqx.Module):
:ivar _parameter_mappings:
:class:`ParameterMappingForCondition` instances for each simulation condition.
:ivar _measurements:
Subset measurement dataframes for each simulation condition.
Preprocessed arrays for each simulation condition.
:ivar _petab_problem:
PEtab problem to simulate.
"""
Expand All @@ -87,7 +88,6 @@ class JAXProblem(eqx.Module):
np.ndarray,
np.ndarray,
np.ndarray,
np.ndarray,
],
]
_inputs: dict[str, dict[str, np.ndarray]]
Expand Down Expand Up @@ -188,7 +188,6 @@ def _get_measurements(
np.ndarray,
np.ndarray,
np.ndarray,
np.ndarray,
],
],
dict[tuple[str, ...], tuple[int, ...]],
Expand All @@ -214,11 +213,9 @@ def _get_measurements(
)

ts = m[petab.TIME]
ts_preeq = ts[np.isfinite(ts) & (ts == 0)]
ts_dyn = ts[np.isfinite(ts) & (ts > 0)]
ts_dyn = ts[np.isfinite(ts)]
ts_posteq = ts[np.logical_not(np.isfinite(ts))]
index = pd.concat([ts_preeq, ts_dyn, ts_posteq]).index
ts_preeq = ts_preeq.values
index = pd.concat([ts_dyn, ts_posteq]).index
ts_dyn = ts_dyn.values
ts_posteq = ts_posteq.values
my = m[petab.MEASUREMENT].values
Expand Down Expand Up @@ -246,7 +243,6 @@ def _get_measurements(
iy_trafos = np.zeros_like(iys)

measurements[tuple(simulation_condition)] = (
ts_preeq,
ts_dyn,
ts_posteq,
my,
Expand Down Expand Up @@ -600,6 +596,9 @@ def run_simulation(
simulation_condition: tuple[str, ...],
solver: diffrax.AbstractSolver,
controller: diffrax.AbstractStepSizeController,
steady_state_event: Callable[
..., diffrax._custom_types.BoolScalarLike
],
max_steps: jnp.int_,
x_preeq: jt.Float[jt.Array, "*nx"] = jnp.array([]), # noqa: F821, F722
ret: ReturnValue = ReturnValue.llh,
Expand All @@ -622,16 +621,15 @@ def run_simulation(
:return:
Tuple of output value and simulation statistics
"""
ts_preeq, ts_dyn, ts_posteq, my, iys, iy_trafos = self._measurements[
ts_dyn, ts_posteq, my, iys, iy_trafos = self._measurements[
simulation_condition
]
p = self.load_model_parameters(simulation_condition[0])
mask_reinit, x_reinit = self.load_reinitialisation(
simulation_condition[0], p
)
return self.model.simulate_condition(
p=eqx.debug.backward_nan(p),
ts_init=jax.lax.stop_gradient(jnp.array(ts_preeq)),
p=p,
ts_dyn=jax.lax.stop_gradient(jnp.array(ts_dyn)),
ts_posteq=jax.lax.stop_gradient(jnp.array(ts_posteq)),
my=jax.lax.stop_gradient(jnp.array(my)),
Expand All @@ -643,6 +641,7 @@ def run_simulation(
solver=solver,
controller=controller,
max_steps=max_steps,
steady_state_event=steady_state_event,
adjoint=diffrax.RecursiveCheckpointAdjoint()
if ret in (ReturnValue.llh, ReturnValue.chi2)
else diffrax.DirectAdjoint(),
Expand All @@ -654,6 +653,9 @@ def run_preequilibration(
simulation_condition: str,
solver: diffrax.AbstractSolver,
controller: diffrax.AbstractStepSizeController,
steady_state_event: Callable[
..., diffrax._custom_types.BoolScalarLike
],
max_steps: jnp.int_,
) -> tuple[jt.Float[jt.Array, "nx"], dict]: # noqa: F821
"""
Expand All @@ -675,12 +677,13 @@ def run_preequilibration(
simulation_condition, p
)
return self.model.preequilibrate_condition(
p=eqx.debug.backward_nan(p),
p=p,
mask_reinit=mask_reinit,
x_reinit=x_reinit,
solver=solver,
controller=controller,
max_steps=max_steps,
steady_state_event=steady_state_event,
)


Expand All @@ -691,6 +694,9 @@ def run_simulations(
controller: diffrax.AbstractStepSizeController = diffrax.PIDController(
**DEFAULT_CONTROLLER_SETTINGS
),
steady_state_event: Callable[
..., diffrax._custom_types.BoolScalarLike
] = diffrax.steady_state_event(),
max_steps: int = 2**10,
ret: ReturnValue | str = ReturnValue.llh,
):
Expand All @@ -705,6 +711,9 @@ def run_simulations(
ODE solver to use for simulation.
:param controller:
Step size controller to use for simulation.
:param steady_state_event:
Steady state event function to use for pre-/post-equilibration. Allows customisation of the steady state
condition, see :func:`diffrax.steady_state_event` for details.
:param max_steps:
Maximum number of steps to take during simulation.
:param ret:
Expand All @@ -719,7 +728,9 @@ def run_simulations(
simulation_conditions = problem.get_all_simulation_conditions()

preeqs = {
sc: problem.run_preequilibration(sc, solver, controller, max_steps)
sc: problem.run_preequilibration(
sc, solver, controller, steady_state_event, max_steps
)
# only run preequilibration once per condition
for sc in {sc[1] for sc in simulation_conditions if len(sc) > 1}
}
Expand All @@ -729,6 +740,7 @@ def run_simulations(
sc,
solver,
controller,
steady_state_event,
max_steps,
preeqs.get(sc[1])[0] if len(sc) > 1 else jnp.array([]),
ret=ret,
Expand All @@ -753,6 +765,9 @@ def petab_simulate(
controller: diffrax.AbstractStepSizeController = diffrax.PIDController(
**DEFAULT_CONTROLLER_SETTINGS
),
steady_state_event: Callable[
..., diffrax._custom_types.BoolScalarLike
] = diffrax.steady_state_event(),
max_steps: int = 2**10,
):
"""
Expand All @@ -773,14 +788,15 @@ def petab_simulate(
problem,
solver=solver,
controller=controller,
steady_state_event=steady_state_event,
max_steps=max_steps,
ret=ReturnValue.y,
)
dfs = []
for sc, ys in y.items():
obs = [
problem.model.observable_ids[io]
for io in problem._measurements[sc][4]
for io in problem._measurements[sc][3]
]
t = jnp.concat(problem._measurements[sc][:2])
df_sc = pd.DataFrame(
Expand Down
Loading

0 comments on commit 2ec62a5

Please sign in to comment.