Skip to content

Commit

Permalink
State reinitialisation in JAX (#2619)
Browse files Browse the repository at this point in the history
* disentangle sim & preeq

* disentangle sim & preeq

* run preequilibration once

* fix symlink

* separate default dirs for jax/cpp, honour model dir/name

* fix notebook

* fix path SNAFU

* fix models without preequilibration

* fix tests

* fixup

* fix doc typehints

* fix notebook

* implement jax-based reinitialisation

* add more defaults & doc

* fix state ids

* fix template

* Update model.py

* breaking jax release

* add jax runner to petab testsuite & fix

* fix notebook

* refactor petab test cases

* fix parameter unscaling

* fixups

* refactor & simplify

* fixup

* fix notebook

* fixup

* Update petab.py
FFroehlich authored Dec 8, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
1 parent 449041d commit f9e64de
Showing 12 changed files with 470 additions and 124 deletions.
7 changes: 4 additions & 3 deletions python/examples/example_jax_petab/ExampleJaxPEtab.ipynb
Original file line number Diff line number Diff line change
@@ -352,12 +352,13 @@
"source": [
"import jax.numpy as jnp\n",
"import diffrax\n",
"from amici.jax import ReturnValue\n",
"\n",
"# Define the simulation condition\n",
"simulation_condition = (\"model1_data1\",)\n",
"\n",
"# Load condition-specific data\n",
"ts_init, ts_dyn, ts_posteq, my, iys = jax_problem._measurements[\n",
"ts_init, ts_dyn, ts_posteq, my, iys, iy_trafos = jax_problem._measurements[\n",
" simulation_condition\n",
"]\n",
"\n",
@@ -375,12 +376,12 @@
" ts_posteq=ts_posteq,\n",
" my=jnp.array(my),\n",
" iys=jnp.array(iys),\n",
" x_preeq=jnp.array([]),\n",
" iy_trafos=jnp.array(iy_trafos),\n",
" solver=diffrax.Kvaerno5(),\n",
" controller=diffrax.PIDController(atol=1e-8, rtol=1e-8),\n",
" max_steps=2**10,\n",
" adjoint=diffrax.DirectAdjoint(),\n",
" ret=\"y\", # Return observables\n",
" ret=ReturnValue.y, # Return observables\n",
" )[0]\n",
"\n",
"\n",
15 changes: 13 additions & 2 deletions python/sdist/amici/jax/__init__.py
Original file line number Diff line number Diff line change
@@ -9,7 +9,12 @@

from warnings import warn

from amici.jax.petab import JAXProblem, run_simulations
from amici.jax.petab import (
JAXProblem,
run_simulations,
petab_simulate,
ReturnValue,
)
from amici.jax.model import JAXModel

warn(
@@ -18,4 +23,10 @@
stacklevel=2,
)

__all__ = ["JAXModel", "JAXProblem", "run_simulations"]
__all__ = [
"JAXModel",
"JAXProblem",
"run_simulations",
"petab_simulate",
"ReturnValue",
]
2 changes: 1 addition & 1 deletion python/sdist/amici/jax/jax.template.py
Original file line number Diff line number Diff line change
@@ -97,7 +97,7 @@ def observable_ids(self):

@property
def state_ids(self):
return TPL_X_IDS
return TPL_X_RDATA_IDS

@property
def parameter_ids(self):
135 changes: 93 additions & 42 deletions python/sdist/amici/jax/model.py
Original file line number Diff line number Diff line change
@@ -4,6 +4,7 @@

from abc import abstractmethod
from pathlib import Path
import enum

import diffrax
import equinox as eqx
@@ -12,6 +13,20 @@
import jaxtyping as jt


class ReturnValue(enum.Enum):
llh = "log-likelihood"
nllhs = "pointwise negative log-likelihood"
x0 = "full initial state vector"
x0_solver = "reduced initial state vector"
x = "full state vector"
x_solver = "reduced state vector"
y = "observables"
sigmay = "standard deviations of the observables"
tcl = "total values for conservation laws"
res = "residuals"
chi2 = "sum(((observed - simulated) / sigma ) ** 2)"


class JAXModel(eqx.Module):
"""
JAXModel provides an abstract base class for a JAX-based implementation of an AMICI model. The class implements
@@ -432,12 +447,15 @@ def simulate_condition(
ts_posteq: jt.Float[jt.Array, "nt_posteq"],
my: jt.Float[jt.Array, "nt"],
iys: jt.Int[jt.Array, "nt"],
x_preeq: jt.Float[jt.Array, "nx"],
iy_trafos: jt.Int[jt.Array, "nt"],
solver: diffrax.AbstractSolver,
controller: diffrax.AbstractStepSizeController,
adjoint: diffrax.AbstractAdjoint,
max_steps: int | jnp.int_,
ret: str = "llh",
x_preeq: jt.Float[jt.Array, "*nx"] = jnp.array([]),
mask_reinit: jt.Bool[jt.Array, "*nx"] = jnp.array([]),
x_reinit: jt.Float[jt.Array, "*nx"] = jnp.array([]),
ret: ReturnValue = ReturnValue.llh,
) -> tuple[jt.Float[jt.Array, "nt *nx"] | jnp.float_, dict]:
r"""
Simulate a condition.
@@ -458,6 +476,13 @@ def simulate_condition(
observed data
:param iys:
indices of the observables according to ordering in :ivar observable_ids:
:param x_preeq:
initial state vector for pre-equilibration. If not provided, the initial state vector is computed using
:meth:`_x0`.
:param mask_reinit:
mask for re-initialization. If `True`, the corresponding state variable is re-initialized.
:param x_reinit:
re-initialized state vector. If not provided, the state vector is not re-initialized.
:param solver:
ODE solver
:param controller:
@@ -468,90 +493,114 @@ def simulate_condition(
:param max_steps:
maximum number of solver steps
:param ret:
which output to return. Valid values are
- `llh`: log-likelihood (default)
- `nllhs`: negative log-likelihood at each time point
- `x0`: full initial state vector (after pre-equilibration)
- `x0_solver`: reduced initial state vector (after pre-equilibration)
- `x`: full state vector
- `x_solver`: reduced state vector
- `y`: observables
- `sigmay`: standard deviations of the observables
- `tcl`: total values for conservation laws (at final timepoint)
- `res`: residuals (observed - simulated)
which output to return. See :class:`ReturnValue` for available options.
:return:
output according to `ret` and statistics
"""
# Pre-equilibration
if x_preeq.shape[0] > 0:
current_x = self._x_solver(x_preeq)
# update tcl with new parameters
tcl = self._tcl(x_preeq, p)
if x_preeq.shape[0]:
x = x_preeq
else:
x0 = self._x0(p)
current_x = self._x_solver(x0)
x = self._x0(p)

# Re-initialization
if x_reinit.shape[0]:
x = jnp.where(mask_reinit, x_reinit, x)
x_solver = self._x_solver(x)
tcl = self._tcl(x, p)

tcl = self._tcl(x0, p)
x_preq = jnp.repeat(current_x.reshape(1, -1), ts_init.shape[0], axis=0)
x_preq = jnp.repeat(x_solver.reshape(1, -1), ts_init.shape[0], axis=0)

# Dynamic simulation
if ts_dyn.shape[0] > 0:
if ts_dyn.shape[0]:
x_dyn, stats_dyn = self._solve(
p,
ts_dyn,
tcl,
current_x,
x_solver,
solver,
controller,
max_steps,
adjoint,
)
current_x = x_dyn[-1, :]
x_solver = x_dyn[-1, :]
else:
x_dyn = jnp.repeat(
current_x.reshape(1, -1), ts_dyn.shape[0], axis=0
x_solver.reshape(1, -1), ts_dyn.shape[0], axis=0
)
stats_dyn = None

# Post-equilibration
if ts_posteq.shape[0] > 0:
current_x, stats_posteq = self._eq(
p, tcl, current_x, solver, controller, max_steps
if ts_posteq.shape[0]:
x_solver, stats_posteq = self._eq(
p, tcl, x_solver, solver, controller, max_steps
)
else:
stats_posteq = None

x_posteq = jnp.repeat(
current_x.reshape(1, -1), ts_posteq.shape[0], axis=0
x_solver.reshape(1, -1), ts_posteq.shape[0], axis=0
)

ts = jnp.concatenate((ts_init, ts_dyn, ts_posteq), axis=0)
x = jnp.concatenate((x_preq, x_dyn, x_posteq), axis=0)

nllhs = self._nllhs(ts, x, p, tcl, my, iys)
llh = -jnp.sum(nllhs)
return {
"llh": llh,
"nllhs": nllhs,
"x": self._x_rdatas(x, tcl),
"x_solver": x,
"y": self._ys(ts, x, p, tcl, iys),
"sigmay": self._sigmays(ts, x, p, tcl, iys),
"x0": self._x_rdata(x[0, :], tcl),
"x0_solver": x[0, :],
"tcl": tcl,
"res": self._ys(ts, x, p, tcl, iys) - my,
}[ret], dict(

stats = dict(
ts=ts,
x=x,
llh=llh,
stats_dyn=stats_dyn,
stats_posteq=stats_posteq,
)
if ret == ReturnValue.llh:
output = llh
elif ret == ReturnValue.nllhs:
output = nllhs
elif ret == ReturnValue.x:
output = self._x_rdatas(x, tcl)
elif ret == ReturnValue.x_solver:
output = x
elif ret == ReturnValue.y:
output = self._ys(ts, x, p, tcl, iys)
elif ret == ReturnValue.sigmay:
output = self._sigmays(ts, x, p, tcl, iys)
elif ret == ReturnValue.x0:
output = self._x_rdata(x[0, :], tcl)
elif ret == ReturnValue.x0_solver:
output = x[0, :]
elif ret == ReturnValue.tcl:
output = tcl
elif ret in (ReturnValue.res, ReturnValue.chi2):
obs_trafo = jax.vmap(
lambda y, iy_trafo: jnp.array(
# needs to follow order in amici.jax.petab.SCALE_TO_INT
[y, safe_log(y), safe_log(y) / jnp.log(10)]
)
.at[iy_trafo]
.get(),
)
ys_obj = obs_trafo(self._ys(ts, x, p, tcl, iys), iy_trafos)
m_obj = obs_trafo(my, iy_trafos)
if ret == ReturnValue.chi2:
output = jnp.sum(
jnp.square(ys_obj - m_obj)
/ jnp.square(self._sigmays(ts, x, p, tcl, iys))
)
else:
output = ys_obj - m_obj
else:
raise NotImplementedError(f"Return value {ret} not implemented.")

return output, stats

@eqx.filter_jit
def preequilibrate_condition(
self,
p: jt.Float[jt.Array, "np"],
x_reinit: jt.Float[jt.Array, "*nx"],
mask_reinit: jt.Bool[jt.Array, "*nx"],
solver: diffrax.AbstractSolver,
controller: diffrax.AbstractStepSizeController,
max_steps: int | jnp.int_,
@@ -572,6 +621,8 @@ def preequilibrate_condition(
"""
# Pre-equilibration
x0 = self._x0(p)
if x_reinit.shape[0]:
x0 = jnp.where(mask_reinit, x_reinit, x0)
tcl = self._tcl(x0, p)
current_x = self._x_solver(x0)
current_x, stats_preeq = self._eq(
2 changes: 1 addition & 1 deletion python/sdist/amici/jax/ode_export.py
Original file line number Diff line number Diff line change
@@ -226,7 +226,7 @@ def _generate_jax_code(self) -> None:
# assign named variables from a jax array
**_jax_variable_assignments(self.model, sym_names),
# tuple of variable names (ids as they are unique)
**_jax_variable_ids(self.model, ("p", "k", "y", "x")),
**_jax_variable_ids(self.model, ("p", "k", "y", "x_rdata")),
**{
"MODEL_NAME": self.model_name,
# keep track of the API version that the model was generated with so we
299 changes: 279 additions & 20 deletions python/sdist/amici/jax/petab.py

Large diffs are not rendered by default.

13 changes: 8 additions & 5 deletions python/sdist/amici/petab/sbml_import.py
Original file line number Diff line number Diff line change
@@ -348,11 +348,14 @@ def import_model_sbml(
_workaround_observable_parameters(
observables, sigmas, sbml_model, output_parameter_defaults
)
fixed_parameters = _workaround_initial_states(
petab_problem=petab_problem,
sbml_model=sbml_model,
**kwargs,
)
if not jax:
fixed_parameters = _workaround_initial_states(
petab_problem=petab_problem,
sbml_model=sbml_model,
**kwargs,
)
else:
fixed_parameters = []

fixed_parameters.extend(
_get_fixed_parameters_sbml(
2 changes: 1 addition & 1 deletion python/sdist/pyproject.toml
Original file line number Diff line number Diff line change
@@ -83,7 +83,7 @@ examples = [
"scipy",
]
jax = [
"jax>=0.4.34",
"jax>=0.4.34,<0.4.36",
"jaxlib>=0.4.34",
"diffrax>=0.6.0",
"jaxtyping>=0.2.34",
6 changes: 4 additions & 2 deletions python/tests/test_jax.py
Original file line number Diff line number Diff line change
@@ -15,7 +15,7 @@
from amici.pysb_import import pysb2amici, pysb2jax
from amici.testing import TemporaryDirectoryWinSafe, skip_on_valgrind
from amici.petab.petab_import import import_petab_problem
from amici.jax import JAXProblem
from amici.jax import JAXProblem, ReturnValue
from numpy.testing import assert_allclose
from test_petab_objective import lotka_volterra # noqa: F401

@@ -177,6 +177,7 @@ def check_fields_jax(
my = my.flatten()
ts = ts.flatten()
iys = iys.flatten()
iy_trafos = np.zeros_like(iys)

ts_init = ts[ts == 0]
ts_dyn = ts[ts > 0]
@@ -194,6 +195,7 @@ def check_fields_jax(
"ts_posteq": jnp.array(ts_posteq),
"my": jnp.array(my),
"iys": jnp.array(iys),
"iy_trafos": jnp.array(iy_trafos),
"x_preeq": jnp.array([]),
"solver": diffrax.Kvaerno5(),
"controller": diffrax.PIDController(atol=ATOL_SIM, rtol=RTOL_SIM),
@@ -206,7 +208,7 @@ def check_fields_jax(
okwargs = kwargs | {
"adjoint": diffrax.DirectAdjoint(),
"max_steps": 2**8,
"ret": output,
"ret": ReturnValue[output],
}
if sensi_order == amici.SensitivityOrder.none:
r_jax[output] = fun(p, **okwargs)[0]
7 changes: 4 additions & 3 deletions tests/benchmark-models/test_petab_benchmark.py
Original file line number Diff line number Diff line change
@@ -340,12 +340,13 @@ def test_jax_llh(benchmark_problem):
[problem_parameters[pid] for pid in jax_problem.parameter_ids]
),
)
llh_jax, _ = beartype(run_simulations)(jax_problem)
if problem_id in problems_for_gradient_check:
(llh_jax, _), sllh_jax = eqx.filter_jit(
eqx.filter_value_and_grad(run_simulations, has_aux=True)
(llh_jax, _), sllh_jax = eqx.filter_value_and_grad(
run_simulations, has_aux=True
)(jax_problem)
else:
llh_jax, _ = beartype(eqx.filter_jit(run_simulations))(jax_problem)
llh_jax, _ = beartype(run_simulations)(jax_problem)

np.testing.assert_allclose(
llh_jax,
16 changes: 9 additions & 7 deletions tests/petab_test_suite/conftest.py
Original file line number Diff line number Diff line change
@@ -60,7 +60,7 @@ def pytest_generate_tests(metafunc):

if metafunc.config.getoption("--only-sbml"):
argvalues = [
(case, "sbml", version)
(case, "sbml", version, False)
for version in ("v1.0.0", "v2.0.0")
for case in (
test_numbers
@@ -70,7 +70,7 @@ def pytest_generate_tests(metafunc):
]
elif metafunc.config.getoption("--only-pysb"):
argvalues = [
(case, "pysb", "v2.0.0")
(case, "pysb", "v2.0.0", False)
for case in (
test_numbers
if test_numbers
@@ -81,8 +81,10 @@ def pytest_generate_tests(metafunc):
argvalues = []
for version in ("v1.0.0", "v2.0.0"):
for format in ("sbml", "pysb"):
argvalues.extend(
(case, format, version)
for case in test_numbers or get_cases(format, version)
)
metafunc.parametrize("case,model_type,version", argvalues)
for jax in (True, False):
argvalues.extend(
(case, format, version, jax)
for case in test_numbers
or get_cases(format, version)
)
metafunc.parametrize("case,model_type,version,jax", argvalues)
90 changes: 53 additions & 37 deletions tests/petab_test_suite/test_petab_suite.py
Original file line number Diff line number Diff line change
@@ -23,10 +23,10 @@
logger.addHandler(stream_handler)


def test_case(case, model_type, version):
def test_case(case, model_type, version, jax):
"""Wrapper for _test_case for handling test outcomes"""
try:
_test_case(case, model_type, version)
_test_case(case, model_type, version, jax)
except Exception as e:
if isinstance(
e, NotImplementedError
@@ -41,10 +41,10 @@ def test_case(case, model_type, version):
raise e


def _test_case(case, model_type, version):
def _test_case(case, model_type, version, jax):
"""Run a single PEtab test suite case"""
case = petabtests.test_id_str(case)
logger.debug(f"Case {case} [{model_type}] [{version}]")
logger.debug(f"Case {case} [{model_type}] [{version}] [{jax}]")

# load
case_dir = petabtests.get_case_dir(case, model_type, version)
@@ -57,34 +57,46 @@ def _test_case(case, model_type, version):
model_name = (
f"petab_{model_type}_test_case_{case}" f"_{version.replace('.', '_')}"
)
model_output_dir = f"amici_models/{model_name}"
model_output_dir = f"amici_models/{model_name}" + ("_jax" if jax else "")
model = import_petab_problem(
petab_problem=problem,
model_output_dir=model_output_dir,
model_name=model_name,
compile_=True,
jax=jax,
)
solver = model.getSolver()
solver.setSteadyStateToleranceFactor(1.0)
problem_parameters = dict(
zip(problem.x_free_ids, problem.x_nominal_free, strict=True)
)
if jax:
from amici.jax import JAXProblem, run_simulations, petab_simulate

jax_problem = JAXProblem(model, problem)
llh, ret = run_simulations(jax_problem)
chi2, _ = run_simulations(jax_problem, ret="chi2")
simulation_df = petab_simulate(jax_problem)
simulation_df.rename(
columns={petab.SIMULATION: petab.MEASUREMENT}, inplace=True
)
else:
solver = model.getSolver()
solver.setSteadyStateToleranceFactor(1.0)
problem_parameters = dict(
zip(problem.x_free_ids, problem.x_nominal_free, strict=True)
)

# simulate
ret = simulate_petab(
problem,
model,
problem_parameters=problem_parameters,
solver=solver,
log_level=logging.DEBUG,
)
# simulate
ret = simulate_petab(
problem,
model,
problem_parameters=problem_parameters,
solver=solver,
log_level=logging.DEBUG,
)

rdatas = ret["rdatas"]
chi2 = sum(rdata["chi2"] for rdata in rdatas)
llh = ret["llh"]
simulation_df = rdatas_to_measurement_df(
rdatas, model, problem.measurement_df
)
rdatas = ret["rdatas"]
chi2 = sum(rdata["chi2"] for rdata in rdatas)
llh = ret["llh"]
simulation_df = rdatas_to_measurement_df(
rdatas, model, problem.measurement_df
)
petab.check_measurement_df(simulation_df, problem.observable_df)
simulation_df = simulation_df.rename(
columns={petab.MEASUREMENT: petab.SIMULATION}
@@ -142,7 +154,10 @@ def _test_case(case, model_type, version):
f"LLH: simulated: {llh}, expected: {gt_llh}, " f"match = {llhs_match}",
)

check_derivatives(problem, model, solver, problem_parameters)
if jax:
pass # skip derivative checks for now
else:
check_derivatives(problem, model, solver, problem_parameters)

if not all([llhs_match, simulations_match]) or not chi2s_match:
logger.error(f"Case {case} failed.")
@@ -196,18 +211,19 @@ def run():
n_skipped = 0
n_total = 0
for version in ("v1.0.0", "v2.0.0"):
cases = petabtests.get_cases("sbml", version=version)
n_total += len(cases)
for case in cases:
try:
test_case(case, "sbml", version=version)
n_success += 1
except Skipped:
n_skipped += 1
except Exception as e:
# run all despite failures
logger.error(f"Case {case} failed.")
logger.error(e)
for jax in (False, True):
cases = petabtests.get_cases("sbml", version=version)
n_total += len(cases)
for case in cases:
try:
test_case(case, "sbml", version=version, jax=jax)
n_success += 1
except Skipped:
n_skipped += 1
except Exception as e:
# run all despite failures
logger.error(f"Case {case} failed.")
logger.error(e)

logger.info(f"{n_success} / {n_total} successful, " f"{n_skipped} skipped")
if n_success != len(cases):

0 comments on commit f9e64de

Please sign in to comment.