State reinitialisation in JAX (#2619)
* disentangle sim & preeq

* disentangle sim & preeq

* run preequilibration once

* fix symlink

* separate default dirs for jax/cpp, honour model dir/name

* fix notebook

* fix path SNAFU

* fix models without preequilibration

* fix tests

* fixup

* fix doc typehints

* fix notebook

* implement jax-based reinitialisation

* add more defaults & doc

* fix state ids

* fix template

* Update model.py

* breaking jax release

* add jax runner to petab testsuite & fix

* fix notebook

* refactor petab test cases

* fix parameter unscaling

* fixups

* refactor & simplify

* fixup

* fix notebook

* fixup

* Update petab.py
FFroehlich authored Dec 8, 2024
1 parent 449041d commit f9e64de
Showing 12 changed files with 470 additions and 124 deletions.
7 changes: 4 additions & 3 deletions python/examples/example_jax_petab/ExampleJaxPEtab.ipynb
@@ -352,12 +352,13 @@
"source": [
"import jax.numpy as jnp\n",
"import diffrax\n",
"from amici.jax import ReturnValue\n",
"\n",
"# Define the simulation condition\n",
"simulation_condition = (\"model1_data1\",)\n",
"\n",
"# Load condition-specific data\n",
"ts_init, ts_dyn, ts_posteq, my, iys = jax_problem._measurements[\n",
"ts_init, ts_dyn, ts_posteq, my, iys, iy_trafos = jax_problem._measurements[\n",
" simulation_condition\n",
"]\n",
"\n",
@@ -375,12 +376,12 @@
" ts_posteq=ts_posteq,\n",
" my=jnp.array(my),\n",
" iys=jnp.array(iys),\n",
" x_preeq=jnp.array([]),\n",
" iy_trafos=jnp.array(iy_trafos),\n",
" solver=diffrax.Kvaerno5(),\n",
" controller=diffrax.PIDController(atol=1e-8, rtol=1e-8),\n",
" max_steps=2**10,\n",
" adjoint=diffrax.DirectAdjoint(),\n",
" ret=\"y\", # Return observables\n",
" ret=ReturnValue.y, # Return observables\n",
" )[0]\n",
"\n",
"\n",
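Taken together, the updated notebook cell calls simulate_condition roughly as follows. This is a consolidated sketch, not a verbatim copy of the notebook: the `p = jax_problem.load_parameters(...)` line sits between the two hunks above and is assumed unchanged by this commit.

import diffrax
import jax.numpy as jnp
from amici.jax import ReturnValue

# Condition-specific data now also carries observable transformations (iy_trafos)
simulation_condition = ("model1_data1",)
ts_init, ts_dyn, ts_posteq, my, iys, iy_trafos = jax_problem._measurements[
    simulation_condition
]
p = jax_problem.load_parameters(simulation_condition[0])  # assumed from elsewhere in the notebook

# Simulate and return observables rather than the default log-likelihood
y = jax_problem.model.simulate_condition(
    p=p,
    ts_init=ts_init,
    ts_dyn=ts_dyn,
    ts_posteq=ts_posteq,
    my=jnp.array(my),
    iys=jnp.array(iys),
    iy_trafos=jnp.array(iy_trafos),
    solver=diffrax.Kvaerno5(),
    controller=diffrax.PIDController(atol=1e-8, rtol=1e-8),
    max_steps=2**10,
    adjoint=diffrax.DirectAdjoint(),
    ret=ReturnValue.y,  # see the ReturnValue enum in the model.py diff below
)[0]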
15 changes: 13 additions & 2 deletions python/sdist/amici/jax/__init__.py
@@ -9,7 +9,12 @@

 from warnings import warn
 
-from amici.jax.petab import JAXProblem, run_simulations
+from amici.jax.petab import (
+    JAXProblem,
+    run_simulations,
+    petab_simulate,
+    ReturnValue,
+)
 from amici.jax.model import JAXModel
 
 warn(
@@ -18,4 +23,10 @@
     stacklevel=2,
 )
 
-__all__ = ["JAXModel", "JAXProblem", "run_simulations"]
+__all__ = [
+    "JAXModel",
+    "JAXProblem",
+    "run_simulations",
+    "petab_simulate",
+    "ReturnValue",
+]
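A sketch of how the two newly exported names fit into the existing API. Only the imported names are taken from this diff; the construction of `jax_model` and `petab_problem`, the JAXProblem constructor signature, and the return shapes are assumptions based on their use elsewhere in AMICI.

from amici.jax import JAXProblem, ReturnValue, petab_simulate, run_simulations

# jax_model (a JAXModel) and petab_problem (a petab.Problem) assumed already loaded
jax_problem = JAXProblem(jax_model, petab_problem)

# objective-style evaluation: aggregate log-likelihood plus per-condition statistics
llh, stats = run_simulations(jax_problem)

# PEtab-style simulation table via the new helper
simulation_df = petab_simulate(jax_problem)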
2 changes: 1 addition & 1 deletion python/sdist/amici/jax/jax.template.py
@@ -97,7 +97,7 @@ def observable_ids(self):

     @property
     def state_ids(self):
-        return TPL_X_IDS
+        return TPL_X_RDATA_IDS
 
     @property
     def parameter_ids(self):
135 changes: 93 additions & 42 deletions python/sdist/amici/jax/model.py
@@ -4,6 +4,7 @@

 from abc import abstractmethod
 from pathlib import Path
+import enum
 
 import diffrax
 import equinox as eqx
@@ -12,6 +13,20 @@
 import jaxtyping as jt
 
 
+class ReturnValue(enum.Enum):
+    llh = "log-likelihood"
+    nllhs = "pointwise negative log-likelihood"
+    x0 = "full initial state vector"
+    x0_solver = "reduced initial state vector"
+    x = "full state vector"
+    x_solver = "reduced state vector"
+    y = "observables"
+    sigmay = "standard deviations of the observables"
+    tcl = "total values for conservation laws"
+    res = "residuals"
+    chi2 = "sum(((observed - simulated) / sigma ) ** 2)"
+
+
 class JAXModel(eqx.Module):
     """
     JAXModel provides an abstract base class for a JAX-based implementation of an AMICI model. The class implements
Expand Down Expand Up @@ -432,12 +447,15 @@ def simulate_condition(
         ts_posteq: jt.Float[jt.Array, "nt_posteq"],
         my: jt.Float[jt.Array, "nt"],
         iys: jt.Int[jt.Array, "nt"],
-        x_preeq: jt.Float[jt.Array, "nx"],
+        iy_trafos: jt.Int[jt.Array, "nt"],
         solver: diffrax.AbstractSolver,
         controller: diffrax.AbstractStepSizeController,
         adjoint: diffrax.AbstractAdjoint,
         max_steps: int | jnp.int_,
-        ret: str = "llh",
+        x_preeq: jt.Float[jt.Array, "*nx"] = jnp.array([]),
+        mask_reinit: jt.Bool[jt.Array, "*nx"] = jnp.array([]),
+        x_reinit: jt.Float[jt.Array, "*nx"] = jnp.array([]),
+        ret: ReturnValue = ReturnValue.llh,
     ) -> tuple[jt.Float[jt.Array, "nt *nx"] | jnp.float_, dict]:
         r"""
         Simulate a condition.
@@ -458,6 +476,13 @@
             observed data
         :param iys:
             indices of the observables according to ordering in :ivar observable_ids:
+        :param x_preeq:
+            initial state vector for pre-equilibration. If not provided, the initial state vector is computed using
+            :meth:`_x0`.
+        :param mask_reinit:
+            mask for re-initialization. If `True`, the corresponding state variable is re-initialized.
+        :param x_reinit:
+            re-initialized state vector. If not provided, the state vector is not re-initialized.
         :param solver:
             ODE solver
         :param controller:
@@ -468,90 +493,114 @@
         :param max_steps:
             maximum number of solver steps
         :param ret:
-            which output to return. Valid values are
-            - `llh`: log-likelihood (default)
-            - `nllhs`: negative log-likelihood at each time point
-            - `x0`: full initial state vector (after pre-equilibration)
-            - `x0_solver`: reduced initial state vector (after pre-equilibration)
-            - `x`: full state vector
-            - `x_solver`: reduced state vector
-            - `y`: observables
-            - `sigmay`: standard deviations of the observables
-            - `tcl`: total values for conservation laws (at final timepoint)
-            - `res`: residuals (observed - simulated)
+            which output to return. See :class:`ReturnValue` for available options.
         :return:
             output according to `ret` and statistics
         """
         # Pre-equilibration
-        if x_preeq.shape[0] > 0:
-            current_x = self._x_solver(x_preeq)
-            # update tcl with new parameters
-            tcl = self._tcl(x_preeq, p)
+        if x_preeq.shape[0]:
+            x = x_preeq
         else:
-            x0 = self._x0(p)
-            current_x = self._x_solver(x0)
-
-            tcl = self._tcl(x0, p)
-        x_preq = jnp.repeat(current_x.reshape(1, -1), ts_init.shape[0], axis=0)
+            x = self._x0(p)
+
+        # Re-initialization
+        if x_reinit.shape[0]:
+            x = jnp.where(mask_reinit, x_reinit, x)
+        x_solver = self._x_solver(x)
+        tcl = self._tcl(x, p)
+
+        x_preq = jnp.repeat(x_solver.reshape(1, -1), ts_init.shape[0], axis=0)
 
         # Dynamic simulation
-        if ts_dyn.shape[0] > 0:
+        if ts_dyn.shape[0]:
             x_dyn, stats_dyn = self._solve(
                 p,
                 ts_dyn,
                 tcl,
-                current_x,
+                x_solver,
                 solver,
                 controller,
                 max_steps,
                 adjoint,
             )
-            current_x = x_dyn[-1, :]
+            x_solver = x_dyn[-1, :]
         else:
             x_dyn = jnp.repeat(
-                current_x.reshape(1, -1), ts_dyn.shape[0], axis=0
+                x_solver.reshape(1, -1), ts_dyn.shape[0], axis=0
             )
             stats_dyn = None
 
         # Post-equilibration
-        if ts_posteq.shape[0] > 0:
-            current_x, stats_posteq = self._eq(
-                p, tcl, current_x, solver, controller, max_steps
+        if ts_posteq.shape[0]:
+            x_solver, stats_posteq = self._eq(
+                p, tcl, x_solver, solver, controller, max_steps
             )
         else:
             stats_posteq = None
 
         x_posteq = jnp.repeat(
-            current_x.reshape(1, -1), ts_posteq.shape[0], axis=0
+            x_solver.reshape(1, -1), ts_posteq.shape[0], axis=0
         )
 
         ts = jnp.concatenate((ts_init, ts_dyn, ts_posteq), axis=0)
         x = jnp.concatenate((x_preq, x_dyn, x_posteq), axis=0)
 
         nllhs = self._nllhs(ts, x, p, tcl, my, iys)
         llh = -jnp.sum(nllhs)
-        return {
-            "llh": llh,
-            "nllhs": nllhs,
-            "x": self._x_rdatas(x, tcl),
-            "x_solver": x,
-            "y": self._ys(ts, x, p, tcl, iys),
-            "sigmay": self._sigmays(ts, x, p, tcl, iys),
-            "x0": self._x_rdata(x[0, :], tcl),
-            "x0_solver": x[0, :],
-            "tcl": tcl,
-            "res": self._ys(ts, x, p, tcl, iys) - my,
-        }[ret], dict(
+
+        stats = dict(
             ts=ts,
             x=x,
+            llh=llh,
             stats_dyn=stats_dyn,
             stats_posteq=stats_posteq,
         )
+        if ret == ReturnValue.llh:
+            output = llh
+        elif ret == ReturnValue.nllhs:
+            output = nllhs
+        elif ret == ReturnValue.x:
+            output = self._x_rdatas(x, tcl)
+        elif ret == ReturnValue.x_solver:
+            output = x
+        elif ret == ReturnValue.y:
+            output = self._ys(ts, x, p, tcl, iys)
+        elif ret == ReturnValue.sigmay:
+            output = self._sigmays(ts, x, p, tcl, iys)
+        elif ret == ReturnValue.x0:
+            output = self._x_rdata(x[0, :], tcl)
+        elif ret == ReturnValue.x0_solver:
+            output = x[0, :]
+        elif ret == ReturnValue.tcl:
+            output = tcl
+        elif ret in (ReturnValue.res, ReturnValue.chi2):
+            obs_trafo = jax.vmap(
+                lambda y, iy_trafo: jnp.array(
+                    # needs to follow order in amici.jax.petab.SCALE_TO_INT
+                    [y, safe_log(y), safe_log(y) / jnp.log(10)]
+                )
+                .at[iy_trafo]
+                .get(),
+            )
+            ys_obj = obs_trafo(self._ys(ts, x, p, tcl, iys), iy_trafos)
+            m_obj = obs_trafo(my, iy_trafos)
+            if ret == ReturnValue.chi2:
+                output = jnp.sum(
+                    jnp.square(ys_obj - m_obj)
+                    / jnp.square(self._sigmays(ts, x, p, tcl, iys))
+                )
+            else:
+                output = ys_obj - m_obj
+        else:
+            raise NotImplementedError(f"Return value {ret} not implemented.")
+
+        return output, stats
 
     @eqx.filter_jit
     def preequilibrate_condition(
         self,
         p: jt.Float[jt.Array, "np"],
+        x_reinit: jt.Float[jt.Array, "*nx"],
+        mask_reinit: jt.Bool[jt.Array, "*nx"],
         solver: diffrax.AbstractSolver,
         controller: diffrax.AbstractStepSizeController,
         max_steps: int | jnp.int_,
@@ -572,6 +621,8 @@ def preequilibrate_condition(
"""
# Pre-equilibration
x0 = self._x0(p)
if x_reinit.shape[0]:
x0 = jnp.where(mask_reinit, x_reinit, x0)
tcl = self._tcl(x0, p)
current_x = self._x_solver(x0)
current_x, stats_preeq = self._eq(
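The re-initialization semantics introduced in simulate_condition and preequilibrate_condition can be illustrated in isolation. This is a minimal, self-contained sketch of the jnp.where masking pattern used in the diff above; the concrete values are made up.

import jax.numpy as jnp

x0 = jnp.array([1.0, 2.0, 3.0])                # initial state, e.g. from _x0(p) or pre-equilibration
mask_reinit = jnp.array([False, True, False])  # which entries to re-initialize
x_reinit = jnp.array([0.0, 5.0, 0.0])          # replacement values

# entries where mask_reinit is True are overwritten, all others are kept
x = jnp.where(mask_reinit, x_reinit, x0)
print(x)  # [1. 5. 3.]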
2 changes: 1 addition & 1 deletion python/sdist/amici/jax/ode_export.py
@@ -226,7 +226,7 @@ def _generate_jax_code(self) -> None:
             # assign named variables from a jax array
             **_jax_variable_assignments(self.model, sym_names),
             # tuple of variable names (ids as they are unique)
-            **_jax_variable_ids(self.model, ("p", "k", "y", "x")),
+            **_jax_variable_ids(self.model, ("p", "k", "y", "x_rdata")),
             **{
                 "MODEL_NAME": self.model_name,
                 # keep track of the API version that the model was generated with so we
(Diffs for the remaining 7 changed files are not shown.)
