Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid silent preequilibration failure in JAX #2631

Merged
merged 14 commits into from
Dec 19, 2024
1 change: 1 addition & 0 deletions python/examples/example_jax_petab/ExampleJaxPEtab.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,7 @@
" iy_trafos=jnp.array(iy_trafos),\n",
" solver=diffrax.Kvaerno5(),\n",
" controller=diffrax.PIDController(atol=1e-8, rtol=1e-8),\n",
" steady_state_event=diffrax.steady_state_event(),\n",
" max_steps=2**10,\n",
" adjoint=diffrax.DirectAdjoint(),\n",
" ret=ReturnValue.y, # Return observables\n",
Expand Down
52 changes: 48 additions & 4 deletions python/sdist/amici/jax/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
import jax
import jaxtyping as jt

from collections.abc import Callable


class ReturnValue(enum.Enum):
llh = "log-likelihood"
Expand All @@ -32,6 +34,13 @@ class JAXModel(eqx.Module):
JAXModel provides an abstract base class for a JAX-based implementation of an AMICI model. The class implements
routines for simulation and evaluation of derived quantities, model specific implementations need to be provided by
classes inheriting from JAXModel.

:ivar api_version:
API version of the derived class. Needs to match the API version of the base class (MODEL_API_VERSION).
:ivar MODEL_API_VERSION:
API version of the base class.
:ivar jax_py_file:
Path to the JAX model file.
"""

MODEL_API_VERSION = "0.0.2"
Expand Down Expand Up @@ -248,6 +257,9 @@ def _eq(
x0: jt.Float[jt.Array, "nxs"],
solver: diffrax.AbstractSolver,
controller: diffrax.AbstractStepSizeController,
steady_state_event: Callable[
..., diffrax._custom_types.BoolScalarLike
],
max_steps: jnp.int_,
) -> tuple[jt.Float[jt.Array, "1 nxs"], dict]:
"""
Expand Down Expand Up @@ -278,10 +290,20 @@ def _eq(
stepsize_controller=controller,
max_steps=max_steps,
adjoint=diffrax.DirectAdjoint(),
event=diffrax.Event(cond_fn=diffrax.steady_state_event()),
event=diffrax.Event(
cond_fn=steady_state_event,
),
throw=False,
)
return sol.ys[-1, :], sol.stats
# If the event was triggered, the event mask is True and the solution is the steady state. Otherwise, the
# solution is the last state and the event mask is False. In the latter case, we return inf for the steady
# state.
ys = jnp.where(
sol.event_mask,
sol.ys[-1, :],
jnp.inf * jnp.ones_like(sol.ys[-1, :]),
)
return ys, sol.stats

def _solve(
self,
Expand Down Expand Up @@ -450,6 +472,9 @@ def simulate_condition(
solver: diffrax.AbstractSolver,
controller: diffrax.AbstractStepSizeController,
adjoint: diffrax.AbstractAdjoint,
steady_state_event: Callable[
..., diffrax._custom_types.BoolScalarLike
],
max_steps: int | jnp.int_,
x_preeq: jt.Float[jt.Array, "*nx"] = jnp.array([]),
mask_reinit: jt.Bool[jt.Array, "*nx"] = jnp.array([]),
Expand Down Expand Up @@ -525,7 +550,13 @@ def simulate_condition(
# Post-equilibration
if ts_posteq.shape[0]:
x_solver, stats_posteq = self._eq(
p, tcl, x_solver, solver, controller, max_steps
p,
tcl,
x_solver,
solver,
controller,
steady_state_event,
max_steps,
)
else:
stats_posteq = None
Expand Down Expand Up @@ -596,13 +627,20 @@ def preequilibrate_condition(
mask_reinit: jt.Bool[jt.Array, "*nx"],
solver: diffrax.AbstractSolver,
controller: diffrax.AbstractStepSizeController,
steady_state_event: Callable[
..., diffrax._custom_types.BoolScalarLike
],
max_steps: int | jnp.int_,
) -> tuple[jt.Float[jt.Array, "nx"], dict]:
r"""
Simulate a condition.

:param p:
parameters for simulation ordered according to ids in :ivar parameter_ids:
:param x_reinit:
re-initialized state vector. If not provided, the state vector is not re-initialized.
:param mask_reinit:
mask for re-initialization. If `True`, the corresponding state variable is re-initialized.
:param solver:
ODE solver
:param controller:
Expand All @@ -619,7 +657,13 @@ def preequilibrate_condition(
tcl = self._tcl(x0, p)
current_x = self._x_solver(x0)
current_x, stats_preeq = self._eq(
p, tcl, current_x, solver, controller, max_steps
p,
tcl,
current_x,
solver,
controller,
steady_state_event,
max_steps,
)

return self._x_rdata(current_x, tcl), dict(stats_preeq=stats_preeq)
Expand Down
26 changes: 24 additions & 2 deletions python/sdist/amici/jax/petab.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from numbers import Number
from collections.abc import Iterable
from pathlib import Path
from collections.abc import Callable


import diffrax
Expand Down Expand Up @@ -465,6 +466,9 @@ def run_simulation(
simulation_condition: tuple[str, ...],
solver: diffrax.AbstractSolver,
controller: diffrax.AbstractStepSizeController,
steady_state_event: Callable[
..., diffrax._custom_types.BoolScalarLike
],
max_steps: jnp.int_,
x_preeq: jt.Float[jt.Array, "*nx"] = jnp.array([]), # noqa: F821, F722
ret: ReturnValue = ReturnValue.llh,
Expand Down Expand Up @@ -507,6 +511,7 @@ def run_simulation(
solver=solver,
controller=controller,
max_steps=max_steps,
steady_state_event=steady_state_event,
adjoint=diffrax.RecursiveCheckpointAdjoint()
if ret in (ReturnValue.llh, ReturnValue.chi2)
else diffrax.DirectAdjoint(),
Expand All @@ -518,6 +523,9 @@ def run_preequilibration(
simulation_condition: str,
solver: diffrax.AbstractSolver,
controller: diffrax.AbstractStepSizeController,
steady_state_event: Callable[
..., diffrax._custom_types.BoolScalarLike
],
max_steps: jnp.int_,
) -> tuple[jt.Float[jt.Array, "nx"], dict]: # noqa: F821
"""
Expand All @@ -539,12 +547,13 @@ def run_preequilibration(
simulation_condition, p
)
return self.model.preequilibrate_condition(
p=eqx.debug.backward_nan(p),
p=p,
mask_reinit=mask_reinit,
x_reinit=x_reinit,
solver=solver,
controller=controller,
max_steps=max_steps,
steady_state_event=steady_state_event,
)


Expand All @@ -555,6 +564,9 @@ def run_simulations(
controller: diffrax.AbstractStepSizeController = diffrax.PIDController(
**DEFAULT_CONTROLLER_SETTINGS
),
steady_state_event: Callable[
..., diffrax._custom_types.BoolScalarLike
] = diffrax.steady_state_event(),
max_steps: int = 2**10,
ret: ReturnValue | str = ReturnValue.llh,
):
Expand All @@ -569,6 +581,9 @@ def run_simulations(
ODE solver to use for simulation.
:param controller:
Step size controller to use for simulation.
:param steady_state_event:
Steady state event function to use for pre-/post-equilibration. Allows customisation of the steady state
condition, see :func:`diffrax.steady_state_event` for details.
:param max_steps:
Maximum number of steps to take during simulation.
:param ret:
Expand All @@ -583,7 +598,9 @@ def run_simulations(
simulation_conditions = problem.get_all_simulation_conditions()

preeqs = {
sc: problem.run_preequilibration(sc, solver, controller, max_steps)
sc: problem.run_preequilibration(
sc, solver, controller, steady_state_event, max_steps
)
# only run preequilibration once per condition
for sc in {sc[1] for sc in simulation_conditions if len(sc) > 1}
}
Expand All @@ -593,6 +610,7 @@ def run_simulations(
sc,
solver,
controller,
steady_state_event,
max_steps,
preeqs.get(sc[1])[0] if len(sc) > 1 else jnp.array([]),
ret=ret,
Expand All @@ -617,6 +635,9 @@ def petab_simulate(
controller: diffrax.AbstractStepSizeController = diffrax.PIDController(
**DEFAULT_CONTROLLER_SETTINGS
),
steady_state_event: Callable[
..., diffrax._custom_types.BoolScalarLike
] = diffrax.steady_state_event(),
max_steps: int = 2**10,
):
"""
Expand All @@ -637,6 +658,7 @@ def petab_simulate(
problem,
solver=solver,
controller=controller,
steady_state_event=steady_state_event,
max_steps=max_steps,
ret=ReturnValue.y,
)
Expand Down
26 changes: 25 additions & 1 deletion python/tests/test_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@
import diffrax
import numpy as np
from beartype import beartype
from petab.v1.C import PREEQUILIBRATION_CONDITION_ID, SIMULATION_CONDITION_ID

from amici.pysb_import import pysb2amici, pysb2jax
from amici.testing import TemporaryDirectoryWinSafe, skip_on_valgrind
from amici.petab.petab_import import import_petab_problem
from amici.jax import JAXProblem, ReturnValue
from amici.jax import JAXProblem, ReturnValue, run_simulations
from numpy.testing import assert_allclose
from test_petab_objective import lotka_volterra # noqa: F401

Expand Down Expand Up @@ -198,6 +199,7 @@ def check_fields_jax(
"solver": diffrax.Kvaerno5(),
"controller": diffrax.PIDController(atol=ATOL_SIM, rtol=RTOL_SIM),
"adjoint": diffrax.RecursiveCheckpointAdjoint(),
"steady_state_event": diffrax.steady_state_event(),
"max_steps": 2**8, # max_steps
}
fun = beartype(jax_model.simulate_condition)
Expand Down Expand Up @@ -266,6 +268,28 @@ def check_fields_jax(
)


def test_preequilibration_failure(lotka_volterra): # noqa: F811
petab_problem = lotka_volterra
# oscillating system, preequilibation should fail when interaction is active
with TemporaryDirectoryWinSafe(prefix="normal") as model_dir:
jax_model = import_petab_problem(
petab_problem, jax=True, model_output_dir=model_dir
)
jax_problem = JAXProblem(jax_model, petab_problem)
r = run_simulations(jax_problem)
assert not np.isinf(r[0].item())
petab_problem.measurement_df[PREEQUILIBRATION_CONDITION_ID] = (
petab_problem.measurement_df[SIMULATION_CONDITION_ID]
)
with TemporaryDirectoryWinSafe(prefix="failure") as model_dir:
jax_model = import_petab_problem(
petab_problem, jax=True, model_output_dir=model_dir
)
jax_problem = JAXProblem(jax_model, petab_problem)
r = run_simulations(jax_problem)
assert np.isinf(r[0].item())


@skip_on_valgrind
def test_serialisation(lotka_volterra): # noqa: F811
petab_problem = lotka_volterra
Expand Down
4 changes: 3 additions & 1 deletion tests/benchmark-models/test_petab_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from functools import partial
from pathlib import Path

import fiddy
import amici
import numpy as np
Expand Down Expand Up @@ -342,8 +343,9 @@ def test_jax_llh(benchmark_problem):
[problem_parameters[pid] for pid in jax_problem.parameter_ids]
),
)
llh_jax, _ = beartype(run_simulations)(jax_problem)

if problem_id in problems_for_gradient_check:
beartype(run_simulations)(jax_problem)
(llh_jax, _), sllh_jax = eqx.filter_value_and_grad(
run_simulations, has_aux=True
)(jax_problem)
Expand Down
15 changes: 12 additions & 3 deletions tests/petab_test_suite/test_petab_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import logging
import sys

import diffrax

import amici
import pandas as pd
import petab.v1 as petab
Expand Down Expand Up @@ -68,10 +70,17 @@ def _test_case(case, model_type, version, jax):
if jax:
from amici.jax import JAXProblem, run_simulations, petab_simulate

steady_state_event = diffrax.steady_state_event(rtol=1e-6, atol=1e-6)
jax_problem = JAXProblem(model, problem)
llh, ret = run_simulations(jax_problem)
chi2, _ = run_simulations(jax_problem, ret="chi2")
simulation_df = petab_simulate(jax_problem)
llh, ret = run_simulations(
jax_problem, steady_state_event=steady_state_event
)
chi2, _ = run_simulations(
jax_problem, ret="chi2", steady_state_event=steady_state_event
)
simulation_df = petab_simulate(
jax_problem, steady_state_event=steady_state_event
)
simulation_df.rename(
columns={petab.SIMULATION: petab.MEASUREMENT}, inplace=True
)
Expand Down
Loading