Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

State reinitialisation in JAX #2619

Merged
merged 29 commits into from
Dec 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions python/examples/example_jax_petab/ExampleJaxPEtab.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -352,12 +352,13 @@
"source": [
"import jax.numpy as jnp\n",
"import diffrax\n",
"from amici.jax import ReturnValue\n",
"\n",
"# Define the simulation condition\n",
"simulation_condition = (\"model1_data1\",)\n",
"\n",
"# Load condition-specific data\n",
"ts_init, ts_dyn, ts_posteq, my, iys = jax_problem._measurements[\n",
"ts_init, ts_dyn, ts_posteq, my, iys, iy_trafos = jax_problem._measurements[\n",
" simulation_condition\n",
"]\n",
"\n",
Expand All @@ -375,12 +376,12 @@
" ts_posteq=ts_posteq,\n",
" my=jnp.array(my),\n",
" iys=jnp.array(iys),\n",
" x_preeq=jnp.array([]),\n",
" iy_trafos=jnp.array(iy_trafos),\n",
" solver=diffrax.Kvaerno5(),\n",
" controller=diffrax.PIDController(atol=1e-8, rtol=1e-8),\n",
" max_steps=2**10,\n",
" adjoint=diffrax.DirectAdjoint(),\n",
" ret=\"y\", # Return observables\n",
" ret=ReturnValue.y, # Return observables\n",
" )[0]\n",
"\n",
"\n",
Expand Down
15 changes: 13 additions & 2 deletions python/sdist/amici/jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,12 @@

from warnings import warn

from amici.jax.petab import JAXProblem, run_simulations
from amici.jax.petab import (
JAXProblem,
run_simulations,
petab_simulate,
ReturnValue,
)
from amici.jax.model import JAXModel

warn(
Expand All @@ -18,4 +23,10 @@
stacklevel=2,
)

__all__ = ["JAXModel", "JAXProblem", "run_simulations"]
__all__ = [
"JAXModel",
"JAXProblem",
"run_simulations",
"petab_simulate",
"ReturnValue",
]
2 changes: 1 addition & 1 deletion python/sdist/amici/jax/jax.template.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def observable_ids(self):

@property
def state_ids(self):
return TPL_X_IDS
return TPL_X_RDATA_IDS

@property
def parameter_ids(self):
Expand Down
135 changes: 93 additions & 42 deletions python/sdist/amici/jax/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from abc import abstractmethod
from pathlib import Path
import enum

import diffrax
import equinox as eqx
Expand All @@ -12,6 +13,20 @@
import jaxtyping as jt


class ReturnValue(enum.Enum):
llh = "log-likelihood"
nllhs = "pointwise negative log-likelihood"
x0 = "full initial state vector"
x0_solver = "reduced initial state vector"
x = "full state vector"
x_solver = "reduced state vector"
y = "observables"
sigmay = "standard deviations of the observables"
tcl = "total values for conservation laws"
res = "residuals"
chi2 = "sum(((observed - simulated) / sigma ) ** 2)"


class JAXModel(eqx.Module):
"""
JAXModel provides an abstract base class for a JAX-based implementation of an AMICI model. The class implements
Expand Down Expand Up @@ -432,12 +447,15 @@
ts_posteq: jt.Float[jt.Array, "nt_posteq"],
my: jt.Float[jt.Array, "nt"],
iys: jt.Int[jt.Array, "nt"],
x_preeq: jt.Float[jt.Array, "nx"],
iy_trafos: jt.Int[jt.Array, "nt"],
solver: diffrax.AbstractSolver,
controller: diffrax.AbstractStepSizeController,
adjoint: diffrax.AbstractAdjoint,
max_steps: int | jnp.int_,
ret: str = "llh",
x_preeq: jt.Float[jt.Array, "*nx"] = jnp.array([]),
mask_reinit: jt.Bool[jt.Array, "*nx"] = jnp.array([]),
x_reinit: jt.Float[jt.Array, "*nx"] = jnp.array([]),
ret: ReturnValue = ReturnValue.llh,
) -> tuple[jt.Float[jt.Array, "nt *nx"] | jnp.float_, dict]:
r"""
Simulate a condition.
Expand All @@ -458,6 +476,13 @@
observed data
:param iys:
indices of the observables according to ordering in :ivar observable_ids:
:param x_preeq:
initial state vector for pre-equilibration. If not provided, the initial state vector is computed using
:meth:`_x0`.
:param mask_reinit:
mask for re-initialization. If `True`, the corresponding state variable is re-initialized.
:param x_reinit:
re-initialized state vector. If not provided, the state vector is not re-initialized.
:param solver:
ODE solver
:param controller:
Expand All @@ -468,90 +493,114 @@
:param max_steps:
maximum number of solver steps
:param ret:
which output to return. Valid values are
- `llh`: log-likelihood (default)
- `nllhs`: negative log-likelihood at each time point
- `x0`: full initial state vector (after pre-equilibration)
- `x0_solver`: reduced initial state vector (after pre-equilibration)
- `x`: full state vector
- `x_solver`: reduced state vector
- `y`: observables
- `sigmay`: standard deviations of the observables
- `tcl`: total values for conservation laws (at final timepoint)
- `res`: residuals (observed - simulated)
which output to return. See :class:`ReturnValue` for available options.
:return:
output according to `ret` and statistics
"""
# Pre-equilibration
if x_preeq.shape[0] > 0:
current_x = self._x_solver(x_preeq)
# update tcl with new parameters
tcl = self._tcl(x_preeq, p)
if x_preeq.shape[0]:
x = x_preeq
else:
x0 = self._x0(p)
current_x = self._x_solver(x0)
x = self._x0(p)

# Re-initialization
if x_reinit.shape[0]:
x = jnp.where(mask_reinit, x_reinit, x)
x_solver = self._x_solver(x)
tcl = self._tcl(x, p)

tcl = self._tcl(x0, p)
x_preq = jnp.repeat(current_x.reshape(1, -1), ts_init.shape[0], axis=0)
x_preq = jnp.repeat(x_solver.reshape(1, -1), ts_init.shape[0], axis=0)

# Dynamic simulation
if ts_dyn.shape[0] > 0:
if ts_dyn.shape[0]:
x_dyn, stats_dyn = self._solve(
p,
ts_dyn,
tcl,
current_x,
x_solver,
solver,
controller,
max_steps,
adjoint,
)
current_x = x_dyn[-1, :]
x_solver = x_dyn[-1, :]
else:
x_dyn = jnp.repeat(
current_x.reshape(1, -1), ts_dyn.shape[0], axis=0
x_solver.reshape(1, -1), ts_dyn.shape[0], axis=0
)
stats_dyn = None

# Post-equilibration
if ts_posteq.shape[0] > 0:
current_x, stats_posteq = self._eq(
p, tcl, current_x, solver, controller, max_steps
if ts_posteq.shape[0]:
x_solver, stats_posteq = self._eq(

Check warning on line 534 in python/sdist/amici/jax/model.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/model.py#L534

Added line #L534 was not covered by tests
p, tcl, x_solver, solver, controller, max_steps
)
else:
stats_posteq = None

x_posteq = jnp.repeat(
current_x.reshape(1, -1), ts_posteq.shape[0], axis=0
x_solver.reshape(1, -1), ts_posteq.shape[0], axis=0
)

ts = jnp.concatenate((ts_init, ts_dyn, ts_posteq), axis=0)
x = jnp.concatenate((x_preq, x_dyn, x_posteq), axis=0)

nllhs = self._nllhs(ts, x, p, tcl, my, iys)
llh = -jnp.sum(nllhs)
return {
"llh": llh,
"nllhs": nllhs,
"x": self._x_rdatas(x, tcl),
"x_solver": x,
"y": self._ys(ts, x, p, tcl, iys),
"sigmay": self._sigmays(ts, x, p, tcl, iys),
"x0": self._x_rdata(x[0, :], tcl),
"x0_solver": x[0, :],
"tcl": tcl,
"res": self._ys(ts, x, p, tcl, iys) - my,
}[ret], dict(

stats = dict(
ts=ts,
x=x,
llh=llh,
stats_dyn=stats_dyn,
stats_posteq=stats_posteq,
)
if ret == ReturnValue.llh:
output = llh
elif ret == ReturnValue.nllhs:
output = nllhs

Check warning on line 560 in python/sdist/amici/jax/model.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/model.py#L560

Added line #L560 was not covered by tests
elif ret == ReturnValue.x:
output = self._x_rdatas(x, tcl)
elif ret == ReturnValue.x_solver:
output = x

Check warning on line 564 in python/sdist/amici/jax/model.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/model.py#L564

Added line #L564 was not covered by tests
elif ret == ReturnValue.y:
output = self._ys(ts, x, p, tcl, iys)
elif ret == ReturnValue.sigmay:
output = self._sigmays(ts, x, p, tcl, iys)

Check warning on line 568 in python/sdist/amici/jax/model.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/model.py#L568

Added line #L568 was not covered by tests
elif ret == ReturnValue.x0:
output = self._x_rdata(x[0, :], tcl)
elif ret == ReturnValue.x0_solver:
output = x[0, :]

Check warning on line 572 in python/sdist/amici/jax/model.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/model.py#L572

Added line #L572 was not covered by tests
elif ret == ReturnValue.tcl:
output = tcl

Check warning on line 574 in python/sdist/amici/jax/model.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/model.py#L574

Added line #L574 was not covered by tests
elif ret in (ReturnValue.res, ReturnValue.chi2):
obs_trafo = jax.vmap(
lambda y, iy_trafo: jnp.array(
# needs to follow order in amici.jax.petab.SCALE_TO_INT
[y, safe_log(y), safe_log(y) / jnp.log(10)]
)
.at[iy_trafo]
.get(),
)
ys_obj = obs_trafo(self._ys(ts, x, p, tcl, iys), iy_trafos)
m_obj = obs_trafo(my, iy_trafos)
if ret == ReturnValue.chi2:
output = jnp.sum(
jnp.square(ys_obj - m_obj)
/ jnp.square(self._sigmays(ts, x, p, tcl, iys))
)
else:
output = ys_obj - m_obj
else:
raise NotImplementedError(f"Return value {ret} not implemented.")

return output, stats

@eqx.filter_jit
def preequilibrate_condition(
self,
p: jt.Float[jt.Array, "np"],
x_reinit: jt.Float[jt.Array, "*nx"],
mask_reinit: jt.Bool[jt.Array, "*nx"],
solver: diffrax.AbstractSolver,
controller: diffrax.AbstractStepSizeController,
max_steps: int | jnp.int_,
Expand All @@ -572,6 +621,8 @@
"""
# Pre-equilibration
x0 = self._x0(p)
if x_reinit.shape[0]:
x0 = jnp.where(mask_reinit, x_reinit, x0)
tcl = self._tcl(x0, p)
current_x = self._x_solver(x0)
current_x, stats_preeq = self._eq(
Expand Down
2 changes: 1 addition & 1 deletion python/sdist/amici/jax/ode_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def _generate_jax_code(self) -> None:
# assign named variables from a jax array
**_jax_variable_assignments(self.model, sym_names),
# tuple of variable names (ids as they are unique)
**_jax_variable_ids(self.model, ("p", "k", "y", "x")),
**_jax_variable_ids(self.model, ("p", "k", "y", "x_rdata")),
**{
"MODEL_NAME": self.model_name,
# keep track of the API version that the model was generated with so we
Expand Down
Loading
Loading