From f9e64def3aa46978d26aa0dbdd689608141632da Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabian=20Fr=C3=B6hlich?= Date: Sun, 8 Dec 2024 08:51:48 +0000 Subject: [PATCH] State reinitialisation in JAX (#2619) * disentangle sim & preeq * disentangle sim & preeq * run preequilibration once * fix symlink * separate default dirs for jax/cpp, honour model dir/name * fix notebook * fix path SNAFU * fix models without preequilibration * fix tests * fixup * fix doc typehints * fix notebook * implement jax-based reinitialisation * add more defaults & doc * fix state ids * fix template * Update model.py * breaking jax release * add jax runner to petab testsuite & fix * fix notebook * refactor petab test cases * fix parameter unscaling * fixups * refactor & simplify * fixup * fix notebook * fixup * Update petab.py --- .../example_jax_petab/ExampleJaxPEtab.ipynb | 7 +- python/sdist/amici/jax/__init__.py | 15 +- python/sdist/amici/jax/jax.template.py | 2 +- python/sdist/amici/jax/model.py | 135 +++++--- python/sdist/amici/jax/ode_export.py | 2 +- python/sdist/amici/jax/petab.py | 299 ++++++++++++++++-- python/sdist/amici/petab/sbml_import.py | 13 +- python/sdist/pyproject.toml | 2 +- python/tests/test_jax.py | 6 +- .../benchmark-models/test_petab_benchmark.py | 7 +- tests/petab_test_suite/conftest.py | 16 +- tests/petab_test_suite/test_petab_suite.py | 90 +++--- 12 files changed, 470 insertions(+), 124 deletions(-) diff --git a/python/examples/example_jax_petab/ExampleJaxPEtab.ipynb b/python/examples/example_jax_petab/ExampleJaxPEtab.ipynb index f6a4f10e98..1310091f4c 100644 --- a/python/examples/example_jax_petab/ExampleJaxPEtab.ipynb +++ b/python/examples/example_jax_petab/ExampleJaxPEtab.ipynb @@ -352,12 +352,13 @@ "source": [ "import jax.numpy as jnp\n", "import diffrax\n", + "from amici.jax import ReturnValue\n", "\n", "# Define the simulation condition\n", "simulation_condition = (\"model1_data1\",)\n", "\n", "# Load condition-specific data\n", - "ts_init, ts_dyn, ts_posteq, my, iys = jax_problem._measurements[\n", + "ts_init, ts_dyn, ts_posteq, my, iys, iy_trafos = jax_problem._measurements[\n", " simulation_condition\n", "]\n", "\n", @@ -375,12 +376,12 @@ " ts_posteq=ts_posteq,\n", " my=jnp.array(my),\n", " iys=jnp.array(iys),\n", - " x_preeq=jnp.array([]),\n", + " iy_trafos=jnp.array(iy_trafos),\n", " solver=diffrax.Kvaerno5(),\n", " controller=diffrax.PIDController(atol=1e-8, rtol=1e-8),\n", " max_steps=2**10,\n", " adjoint=diffrax.DirectAdjoint(),\n", - " ret=\"y\", # Return observables\n", + " ret=ReturnValue.y, # Return observables\n", " )[0]\n", "\n", "\n", diff --git a/python/sdist/amici/jax/__init__.py b/python/sdist/amici/jax/__init__.py index 8b67abda27..a5b5dc1cae 100644 --- a/python/sdist/amici/jax/__init__.py +++ b/python/sdist/amici/jax/__init__.py @@ -9,7 +9,12 @@ from warnings import warn -from amici.jax.petab import JAXProblem, run_simulations +from amici.jax.petab import ( + JAXProblem, + run_simulations, + petab_simulate, + ReturnValue, +) from amici.jax.model import JAXModel warn( @@ -18,4 +23,10 @@ stacklevel=2, ) -__all__ = ["JAXModel", "JAXProblem", "run_simulations"] +__all__ = [ + "JAXModel", + "JAXProblem", + "run_simulations", + "petab_simulate", + "ReturnValue", +] diff --git a/python/sdist/amici/jax/jax.template.py b/python/sdist/amici/jax/jax.template.py index d395715422..5d5521d222 100644 --- a/python/sdist/amici/jax/jax.template.py +++ b/python/sdist/amici/jax/jax.template.py @@ -97,7 +97,7 @@ def observable_ids(self): @property def state_ids(self): - return TPL_X_IDS + return TPL_X_RDATA_IDS @property def parameter_ids(self): diff --git a/python/sdist/amici/jax/model.py b/python/sdist/amici/jax/model.py index 8f9650ef0f..98e123b5f0 100644 --- a/python/sdist/amici/jax/model.py +++ b/python/sdist/amici/jax/model.py @@ -4,6 +4,7 @@ from abc import abstractmethod from pathlib import Path +import enum import diffrax import equinox as eqx @@ -12,6 +13,20 @@ import jaxtyping as jt +class ReturnValue(enum.Enum): + llh = "log-likelihood" + nllhs = "pointwise negative log-likelihood" + x0 = "full initial state vector" + x0_solver = "reduced initial state vector" + x = "full state vector" + x_solver = "reduced state vector" + y = "observables" + sigmay = "standard deviations of the observables" + tcl = "total values for conservation laws" + res = "residuals" + chi2 = "sum(((observed - simulated) / sigma ) ** 2)" + + class JAXModel(eqx.Module): """ JAXModel provides an abstract base class for a JAX-based implementation of an AMICI model. The class implements @@ -432,12 +447,15 @@ def simulate_condition( ts_posteq: jt.Float[jt.Array, "nt_posteq"], my: jt.Float[jt.Array, "nt"], iys: jt.Int[jt.Array, "nt"], - x_preeq: jt.Float[jt.Array, "nx"], + iy_trafos: jt.Int[jt.Array, "nt"], solver: diffrax.AbstractSolver, controller: diffrax.AbstractStepSizeController, adjoint: diffrax.AbstractAdjoint, max_steps: int | jnp.int_, - ret: str = "llh", + x_preeq: jt.Float[jt.Array, "*nx"] = jnp.array([]), + mask_reinit: jt.Bool[jt.Array, "*nx"] = jnp.array([]), + x_reinit: jt.Float[jt.Array, "*nx"] = jnp.array([]), + ret: ReturnValue = ReturnValue.llh, ) -> tuple[jt.Float[jt.Array, "nt *nx"] | jnp.float_, dict]: r""" Simulate a condition. @@ -458,6 +476,13 @@ def simulate_condition( observed data :param iys: indices of the observables according to ordering in :ivar observable_ids: + :param x_preeq: + initial state vector for pre-equilibration. If not provided, the initial state vector is computed using + :meth:`_x0`. + :param mask_reinit: + mask for re-initialization. If `True`, the corresponding state variable is re-initialized. + :param x_reinit: + re-initialized state vector. If not provided, the state vector is not re-initialized. :param solver: ODE solver :param controller: @@ -468,61 +493,52 @@ def simulate_condition( :param max_steps: maximum number of solver steps :param ret: - which output to return. Valid values are - - `llh`: log-likelihood (default) - - `nllhs`: negative log-likelihood at each time point - - `x0`: full initial state vector (after pre-equilibration) - - `x0_solver`: reduced initial state vector (after pre-equilibration) - - `x`: full state vector - - `x_solver`: reduced state vector - - `y`: observables - - `sigmay`: standard deviations of the observables - - `tcl`: total values for conservation laws (at final timepoint) - - `res`: residuals (observed - simulated) + which output to return. See :class:`ReturnValue` for available options. :return: output according to `ret` and statistics """ - # Pre-equilibration - if x_preeq.shape[0] > 0: - current_x = self._x_solver(x_preeq) - # update tcl with new parameters - tcl = self._tcl(x_preeq, p) + if x_preeq.shape[0]: + x = x_preeq else: - x0 = self._x0(p) - current_x = self._x_solver(x0) + x = self._x0(p) + + # Re-initialization + if x_reinit.shape[0]: + x = jnp.where(mask_reinit, x_reinit, x) + x_solver = self._x_solver(x) + tcl = self._tcl(x, p) - tcl = self._tcl(x0, p) - x_preq = jnp.repeat(current_x.reshape(1, -1), ts_init.shape[0], axis=0) + x_preq = jnp.repeat(x_solver.reshape(1, -1), ts_init.shape[0], axis=0) # Dynamic simulation - if ts_dyn.shape[0] > 0: + if ts_dyn.shape[0]: x_dyn, stats_dyn = self._solve( p, ts_dyn, tcl, - current_x, + x_solver, solver, controller, max_steps, adjoint, ) - current_x = x_dyn[-1, :] + x_solver = x_dyn[-1, :] else: x_dyn = jnp.repeat( - current_x.reshape(1, -1), ts_dyn.shape[0], axis=0 + x_solver.reshape(1, -1), ts_dyn.shape[0], axis=0 ) stats_dyn = None # Post-equilibration - if ts_posteq.shape[0] > 0: - current_x, stats_posteq = self._eq( - p, tcl, current_x, solver, controller, max_steps + if ts_posteq.shape[0]: + x_solver, stats_posteq = self._eq( + p, tcl, x_solver, solver, controller, max_steps ) else: stats_posteq = None x_posteq = jnp.repeat( - current_x.reshape(1, -1), ts_posteq.shape[0], axis=0 + x_solver.reshape(1, -1), ts_posteq.shape[0], axis=0 ) ts = jnp.concatenate((ts_init, ts_dyn, ts_posteq), axis=0) @@ -530,28 +546,61 @@ def simulate_condition( nllhs = self._nllhs(ts, x, p, tcl, my, iys) llh = -jnp.sum(nllhs) - return { - "llh": llh, - "nllhs": nllhs, - "x": self._x_rdatas(x, tcl), - "x_solver": x, - "y": self._ys(ts, x, p, tcl, iys), - "sigmay": self._sigmays(ts, x, p, tcl, iys), - "x0": self._x_rdata(x[0, :], tcl), - "x0_solver": x[0, :], - "tcl": tcl, - "res": self._ys(ts, x, p, tcl, iys) - my, - }[ret], dict( + + stats = dict( ts=ts, x=x, + llh=llh, stats_dyn=stats_dyn, stats_posteq=stats_posteq, ) + if ret == ReturnValue.llh: + output = llh + elif ret == ReturnValue.nllhs: + output = nllhs + elif ret == ReturnValue.x: + output = self._x_rdatas(x, tcl) + elif ret == ReturnValue.x_solver: + output = x + elif ret == ReturnValue.y: + output = self._ys(ts, x, p, tcl, iys) + elif ret == ReturnValue.sigmay: + output = self._sigmays(ts, x, p, tcl, iys) + elif ret == ReturnValue.x0: + output = self._x_rdata(x[0, :], tcl) + elif ret == ReturnValue.x0_solver: + output = x[0, :] + elif ret == ReturnValue.tcl: + output = tcl + elif ret in (ReturnValue.res, ReturnValue.chi2): + obs_trafo = jax.vmap( + lambda y, iy_trafo: jnp.array( + # needs to follow order in amici.jax.petab.SCALE_TO_INT + [y, safe_log(y), safe_log(y) / jnp.log(10)] + ) + .at[iy_trafo] + .get(), + ) + ys_obj = obs_trafo(self._ys(ts, x, p, tcl, iys), iy_trafos) + m_obj = obs_trafo(my, iy_trafos) + if ret == ReturnValue.chi2: + output = jnp.sum( + jnp.square(ys_obj - m_obj) + / jnp.square(self._sigmays(ts, x, p, tcl, iys)) + ) + else: + output = ys_obj - m_obj + else: + raise NotImplementedError(f"Return value {ret} not implemented.") + + return output, stats @eqx.filter_jit def preequilibrate_condition( self, p: jt.Float[jt.Array, "np"], + x_reinit: jt.Float[jt.Array, "*nx"], + mask_reinit: jt.Bool[jt.Array, "*nx"], solver: diffrax.AbstractSolver, controller: diffrax.AbstractStepSizeController, max_steps: int | jnp.int_, @@ -572,6 +621,8 @@ def preequilibrate_condition( """ # Pre-equilibration x0 = self._x0(p) + if x_reinit.shape[0]: + x0 = jnp.where(mask_reinit, x_reinit, x0) tcl = self._tcl(x0, p) current_x = self._x_solver(x0) current_x, stats_preeq = self._eq( diff --git a/python/sdist/amici/jax/ode_export.py b/python/sdist/amici/jax/ode_export.py index cec5104ded..4329195441 100644 --- a/python/sdist/amici/jax/ode_export.py +++ b/python/sdist/amici/jax/ode_export.py @@ -226,7 +226,7 @@ def _generate_jax_code(self) -> None: # assign named variables from a jax array **_jax_variable_assignments(self.model, sym_names), # tuple of variable names (ids as they are unique) - **_jax_variable_ids(self.model, ("p", "k", "y", "x")), + **_jax_variable_ids(self.model, ("p", "k", "y", "x_rdata")), **{ "MODEL_NAME": self.model_name, # keep track of the API version that the model was generated with so we diff --git a/python/sdist/amici/jax/petab.py b/python/sdist/amici/jax/petab.py index 0411e5e2df..b5834223fb 100644 --- a/python/sdist/amici/jax/petab.py +++ b/python/sdist/amici/jax/petab.py @@ -4,6 +4,7 @@ from collections.abc import Iterable from pathlib import Path + import diffrax import equinox as eqx import jaxtyping as jt @@ -18,7 +19,21 @@ ParameterMappingForCondition, create_parameter_mapping, ) -from amici.jax.model import JAXModel +from amici.jax.model import JAXModel, ReturnValue + +DEFAULT_CONTROLLER_SETTINGS = { + "atol": 1e-8, + "rtol": 1e-8, + "pcoeff": 0.4, + "icoeff": 0.3, + "dcoeff": 0.0, +} + +SCALE_TO_INT = { + petab.LIN: 0, + petab.LOG: 1, + petab.LOG10: 2, +} def jax_unscale( @@ -66,8 +81,16 @@ class JAXProblem(eqx.Module): _parameter_mappings: dict[str, ParameterMappingForCondition] _measurements: dict[ tuple[str, ...], - tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray], + tuple[ + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, + ], ] + _petab_measurement_indices: dict[tuple[str, ...], tuple[int, ...]] _petab_problem: petab.Problem def __init__(self, model: JAXModel, petab_problem: petab.Problem): @@ -83,7 +106,9 @@ def __init__(self, model: JAXModel, petab_problem: petab.Problem): scs = petab_problem.get_simulation_conditions_from_measurement_df() self._petab_problem = petab_problem self._parameter_mappings = self._get_parameter_mappings(scs) - self._measurements = self._get_measurements(scs) + self._measurements, self._petab_measurement_indices = ( + self._get_measurements(scs) + ) self.parameters = self._get_nominal_parameter_values() def save(self, directory: Path): @@ -153,9 +178,19 @@ def _get_parameter_mappings( def _get_measurements( self, simulation_conditions: pd.DataFrame - ) -> dict[ - tuple[str, ...], - tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray], + ) -> tuple[ + dict[ + tuple[str, ...], + tuple[ + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, + ], + ], + dict[tuple[str, ...], tuple[int, ...]], ]: """ Get measurements for the model based on the provided simulation conditions. @@ -168,6 +203,7 @@ def _get_measurements( post-equilibrium time points; measurements and observable indices). """ measurements = dict() + indices = dict() for _, simulation_condition in simulation_conditions.iterrows(): query = " & ".join( [f"{k} == '{v}'" for k, v in simulation_condition.items()] @@ -176,10 +212,14 @@ def _get_measurements( by=petab.TIME ) - ts = m[petab.TIME].values + ts = m[petab.TIME] ts_preeq = ts[np.isfinite(ts) & (ts == 0)] ts_dyn = ts[np.isfinite(ts) & (ts > 0)] ts_posteq = ts[np.logical_not(np.isfinite(ts))] + index = pd.concat([ts_preeq, ts_dyn, ts_posteq]).index + ts_preeq = ts_preeq.values + ts_dyn = ts_dyn.values + ts_posteq = ts_posteq.values my = m[petab.MEASUREMENT].values iys = np.array( [ @@ -187,6 +227,22 @@ def _get_measurements( for oid in m[petab.OBSERVABLE_ID].values ] ) + if ( + petab.OBSERVABLE_TRANSFORMATION + in self._petab_problem.observable_df + ): + iy_trafos = np.array( + [ + SCALE_TO_INT[ + self._petab_problem.observable_df.loc[ + oid, petab.OBSERVABLE_TRANSFORMATION + ] + ] + for oid in m[petab.OBSERVABLE_ID].values + ] + ) + else: + iy_trafos = np.zeros_like(iys) measurements[tuple(simulation_condition)] = ( ts_preeq, @@ -194,8 +250,10 @@ def _get_measurements( ts_posteq, my, iys, + iy_trafos, ) - return measurements + indices[tuple(simulation_condition)] = tuple(index.tolist()) + return measurements, indices def get_all_simulation_conditions(self) -> tuple[tuple[str, ...], ...]: simulation_conditions = ( @@ -292,6 +350,112 @@ def load_parameters( ) return self._unscale(p, pscale) + def _state_needs_reinitialisation( + self, + simulation_condition: str, + state_id: str, + ) -> bool: + """ + Check if a state needs reinitialisation for a simulation condition. + + :param simulation_condition: + simulation condition to check reinitialisation for + :param state_id: + state id to check reinitialisation for + :return: + True if state needs reinitialisation, False otherwise + """ + if state_id not in self._petab_problem.condition_df: + return False + xval = self._petab_problem.condition_df.loc[ + simulation_condition, state_id + ] + if isinstance(xval, Number) and np.isnan(xval): + return False + return True + + def _state_reinitialisation_value( + self, + simulation_condition: str, + state_id: str, + p: jt.Float[jt.Array, "np"], + ) -> jt.Float[jt.Scalar, ""] | float: # noqa: F722 + """ + Get the reinitialisation value for a state. + + :param simulation_condition: + simulation condition to get reinitialisation value for + :param state_id: + state id to get reinitialisation value for + :param p: + parameters for the simulation condition + :return: + reinitialisation value for the state + """ + if state_id not in self._petab_problem.condition_df: + # no reinitialisation, return dummy value + return 0.0 + xval = self._petab_problem.condition_df.loc[ + simulation_condition, state_id + ] + if isinstance(xval, Number) and np.isnan(xval): + # no reinitialisation, return dummy value + return 0.0 + if isinstance(xval, Number): + # numerical value, return as is + return xval + if xval in self.model.parameter_ids: + # model parameter, return value + return p[self.model.parameter_ids.index(xval)] + if xval in self.parameter_ids: + # estimated PEtab parameter, return unscaled value + return jax_unscale( + self.get_petab_parameter_by_id(xval), + self._petab_problem.parameter_df.loc[ + xval, petab.PARAMETER_SCALE + ], + ) + # only remaining option is nominal value for PEtab parameter + # that is not estimated, return nominal value + return self._petab_problem.parameter_df.loc[xval, petab.NOMINAL_VALUE] + + def load_reinitialisation( + self, + simulation_condition: str, + p: jt.Float[jt.Array, "np"], + ) -> tuple[jt.Bool[jt.Array, "nx"], jt.Float[jt.Array, "nx"]]: # noqa: F821 + """ + Load reinitialisation values and mask for the state vector for a simulation condition. + + :param simulation_condition: + Simulation condition to load reinitialisation for. + :param p: + Parameters for the simulation condition. + :return: + Tuple of reinitialisation masm and value for states. + """ + if not any( + x_id in self._petab_problem.condition_df + for x_id in self.model.state_ids + ): + return jnp.array([]), jnp.array([]) + + mask = jnp.array( + [ + self._state_needs_reinitialisation(simulation_condition, x_id) + for x_id in self.model.state_ids + ] + ) + reinit_x = jnp.array( + [ + self._state_reinitialisation_value( + simulation_condition, x_id, p + ) + for x_id in self.model.state_ids + ] + ) + return mask, reinit_x + def update_parameters(self, p: jt.Float[jt.Array, "np"]) -> "JAXProblem": """ Update parameters for the model. @@ -308,6 +472,7 @@ def run_simulation( controller: diffrax.AbstractStepSizeController, max_steps: jnp.int_, x_preeq: jt.Float[jt.Array, "*nx"] = jnp.array([]), # noqa: F821, F722 + ret: ReturnValue = ReturnValue.llh, ) -> tuple[jnp.float_, dict]: """ Run a simulation for a given simulation condition. @@ -322,25 +487,36 @@ def run_simulation( Maximum number of steps to take during simulation :param x_preeq: Pre-equilibration state if available + :param ret: + which output to return. See :class:`ReturnValue` for available options. :return: - Tuple of log-likelihood and simulation statistics + Tuple of output value and simulation statistics """ - ts_preeq, ts_dyn, ts_posteq, my, iys = self._measurements[ + ts_preeq, ts_dyn, ts_posteq, my, iys, iy_trafos = self._measurements[ simulation_condition ] p = self.load_parameters(simulation_condition[0]) + mask_reinit, x_reinit = self.load_reinitialisation( + simulation_condition[0], p + ) return self.model.simulate_condition( - p=p, + p=eqx.debug.backward_nan(p), ts_init=jax.lax.stop_gradient(jnp.array(ts_preeq)), ts_dyn=jax.lax.stop_gradient(jnp.array(ts_dyn)), ts_posteq=jax.lax.stop_gradient(jnp.array(ts_posteq)), my=jax.lax.stop_gradient(jnp.array(my)), iys=jax.lax.stop_gradient(jnp.array(iys)), + iy_trafos=jax.lax.stop_gradient(jnp.array(iy_trafos)), x_preeq=x_preeq, + mask_reinit=jax.lax.stop_gradient(mask_reinit), + x_reinit=x_reinit, solver=solver, controller=controller, max_steps=max_steps, - adjoint=diffrax.RecursiveCheckpointAdjoint(), + adjoint=diffrax.RecursiveCheckpointAdjoint() + if ret in (ReturnValue.llh, ReturnValue.chi2) + else diffrax.DirectAdjoint(), + ret=ret, ) def run_preequilibration( @@ -365,8 +541,13 @@ def run_preequilibration( Pre-equilibration state """ p = self.load_parameters(simulation_condition) + mask_reinit, x_reinit = self.load_reinitialisation( + simulation_condition, p + ) return self.model.preequilibrate_condition( - p=p, + p=eqx.debug.backward_nan(p), + mask_reinit=mask_reinit, + x_reinit=x_reinit, solver=solver, controller=controller, max_steps=max_steps, @@ -378,13 +559,10 @@ def run_simulations( simulation_conditions: Iterable[tuple[str, ...]] | None = None, solver: diffrax.AbstractSolver = diffrax.Kvaerno5(), controller: diffrax.AbstractStepSizeController = diffrax.PIDController( - rtol=1e-8, - atol=1e-8, - pcoeff=0.4, - icoeff=0.3, - dcoeff=0.0, + **DEFAULT_CONTROLLER_SETTINGS ), max_steps: int = 2**10, + ret: ReturnValue | str = ReturnValue.llh, ): """ Run simulations for a problem. @@ -399,9 +577,14 @@ def run_simulations( Step size controller to use for simulation. :param max_steps: Maximum number of steps to take during simulation. + :param ret: + which output to return. See :class:`ReturnValue` for available options. :return: - Overall negative log-likelihood and condition specific results and statistics. + Overall output value and condition specific results and statistics. """ + if isinstance(ret, str): + ret = ReturnValue[ret] + if simulation_conditions is None: simulation_conditions = problem.get_all_simulation_conditions() @@ -418,10 +601,86 @@ def run_simulations( controller, max_steps, preeqs.get(sc[1])[0] if len(sc) > 1 else jnp.array([]), + ret=ret, ) for sc in simulation_conditions } - return sum(llh for llh, _ in results.values()), { + stats = { sc: res[1] | preeqs[sc[1]][1] if len(sc) > 1 else res[1] for sc, res in results.items() } + if ret in (ReturnValue.llh, ReturnValue.chi2): + output = sum(r for r, _ in results.values()) + else: + output = {sc: res[0] for sc, res in results.items()} + + return output, stats + + +def petab_simulate( + problem: JAXProblem, + solver: diffrax.AbstractSolver = diffrax.Kvaerno5(), + controller: diffrax.AbstractStepSizeController = diffrax.PIDController( + **DEFAULT_CONTROLLER_SETTINGS + ), + max_steps: int = 2**10, +): + """ + Run simulations for a problem and return the results as a petab simulation dataframe. + + :param problem: + Problem to run simulations for. + :param solver: + ODE solver to use for simulation. + :param controller: + Step size controller to use for simulation. + :param max_steps: + Maximum number of steps to take during simulation. + :return: + petab simulation dataframe. + """ + y, r = run_simulations( + problem, + solver=solver, + controller=controller, + max_steps=max_steps, + ret=ReturnValue.y, + ) + dfs = [] + for sc, ys in y.items(): + obs = [ + problem.model.observable_ids[io] + for io in problem._measurements[sc][4] + ] + t = jnp.concat(problem._measurements[sc][:2]) + df_sc = pd.DataFrame( + { + petab.SIMULATION: ys, + petab.TIME: t, + petab.OBSERVABLE_ID: obs, + petab.SIMULATION_CONDITION_ID: [sc[0]] * len(t), + }, + index=problem._petab_measurement_indices[sc], + ) + if ( + petab.OBSERVABLE_PARAMETERS + in problem._petab_problem.measurement_df + ): + df_sc[petab.OBSERVABLE_PARAMETERS] = ( + problem._petab_problem.measurement_df.query( + f"{petab.SIMULATION_CONDITION_ID} == '{sc[0]}'" + )[petab.OBSERVABLE_PARAMETERS] + ) + if petab.NOISE_PARAMETERS in problem._petab_problem.measurement_df: + df_sc[petab.NOISE_PARAMETERS] = ( + problem._petab_problem.measurement_df.query( + f"{petab.SIMULATION_CONDITION_ID} == '{sc[0]}'" + )[petab.NOISE_PARAMETERS] + ) + if ( + petab.PREEQUILIBRATION_CONDITION_ID + in problem._petab_problem.measurement_df + ): + df_sc[petab.PREEQUILIBRATION_CONDITION_ID] = sc[1] + dfs.append(df_sc) + return pd.concat(dfs).sort_index() diff --git a/python/sdist/amici/petab/sbml_import.py b/python/sdist/amici/petab/sbml_import.py index e4e7efd7fc..e605a9cc80 100644 --- a/python/sdist/amici/petab/sbml_import.py +++ b/python/sdist/amici/petab/sbml_import.py @@ -348,11 +348,14 @@ def import_model_sbml( _workaround_observable_parameters( observables, sigmas, sbml_model, output_parameter_defaults ) - fixed_parameters = _workaround_initial_states( - petab_problem=petab_problem, - sbml_model=sbml_model, - **kwargs, - ) + if not jax: + fixed_parameters = _workaround_initial_states( + petab_problem=petab_problem, + sbml_model=sbml_model, + **kwargs, + ) + else: + fixed_parameters = [] fixed_parameters.extend( _get_fixed_parameters_sbml( diff --git a/python/sdist/pyproject.toml b/python/sdist/pyproject.toml index 6441ac3300..b62903240e 100644 --- a/python/sdist/pyproject.toml +++ b/python/sdist/pyproject.toml @@ -83,7 +83,7 @@ examples = [ "scipy", ] jax = [ - "jax>=0.4.34", + "jax>=0.4.34,<0.4.36", "jaxlib>=0.4.34", "diffrax>=0.6.0", "jaxtyping>=0.2.34", diff --git a/python/tests/test_jax.py b/python/tests/test_jax.py index ce7018e078..ef9cbde576 100644 --- a/python/tests/test_jax.py +++ b/python/tests/test_jax.py @@ -15,7 +15,7 @@ from amici.pysb_import import pysb2amici, pysb2jax from amici.testing import TemporaryDirectoryWinSafe, skip_on_valgrind from amici.petab.petab_import import import_petab_problem -from amici.jax import JAXProblem +from amici.jax import JAXProblem, ReturnValue from numpy.testing import assert_allclose from test_petab_objective import lotka_volterra # noqa: F401 @@ -177,6 +177,7 @@ def check_fields_jax( my = my.flatten() ts = ts.flatten() iys = iys.flatten() + iy_trafos = np.zeros_like(iys) ts_init = ts[ts == 0] ts_dyn = ts[ts > 0] @@ -194,6 +195,7 @@ def check_fields_jax( "ts_posteq": jnp.array(ts_posteq), "my": jnp.array(my), "iys": jnp.array(iys), + "iy_trafos": jnp.array(iy_trafos), "x_preeq": jnp.array([]), "solver": diffrax.Kvaerno5(), "controller": diffrax.PIDController(atol=ATOL_SIM, rtol=RTOL_SIM), @@ -206,7 +208,7 @@ def check_fields_jax( okwargs = kwargs | { "adjoint": diffrax.DirectAdjoint(), "max_steps": 2**8, - "ret": output, + "ret": ReturnValue[output], } if sensi_order == amici.SensitivityOrder.none: r_jax[output] = fun(p, **okwargs)[0] diff --git a/tests/benchmark-models/test_petab_benchmark.py b/tests/benchmark-models/test_petab_benchmark.py index 6a388f7493..7c70015a8c 100644 --- a/tests/benchmark-models/test_petab_benchmark.py +++ b/tests/benchmark-models/test_petab_benchmark.py @@ -340,12 +340,13 @@ def test_jax_llh(benchmark_problem): [problem_parameters[pid] for pid in jax_problem.parameter_ids] ), ) + llh_jax, _ = beartype(run_simulations)(jax_problem) if problem_id in problems_for_gradient_check: - (llh_jax, _), sllh_jax = eqx.filter_jit( - eqx.filter_value_and_grad(run_simulations, has_aux=True) + (llh_jax, _), sllh_jax = eqx.filter_value_and_grad( + run_simulations, has_aux=True )(jax_problem) else: - llh_jax, _ = beartype(eqx.filter_jit(run_simulations))(jax_problem) + llh_jax, _ = beartype(run_simulations)(jax_problem) np.testing.assert_allclose( llh_jax, diff --git a/tests/petab_test_suite/conftest.py b/tests/petab_test_suite/conftest.py index 2e1c6d3cea..b51f240ffd 100644 --- a/tests/petab_test_suite/conftest.py +++ b/tests/petab_test_suite/conftest.py @@ -60,7 +60,7 @@ def pytest_generate_tests(metafunc): if metafunc.config.getoption("--only-sbml"): argvalues = [ - (case, "sbml", version) + (case, "sbml", version, False) for version in ("v1.0.0", "v2.0.0") for case in ( test_numbers @@ -70,7 +70,7 @@ def pytest_generate_tests(metafunc): ] elif metafunc.config.getoption("--only-pysb"): argvalues = [ - (case, "pysb", "v2.0.0") + (case, "pysb", "v2.0.0", False) for case in ( test_numbers if test_numbers @@ -81,8 +81,10 @@ def pytest_generate_tests(metafunc): argvalues = [] for version in ("v1.0.0", "v2.0.0"): for format in ("sbml", "pysb"): - argvalues.extend( - (case, format, version) - for case in test_numbers or get_cases(format, version) - ) - metafunc.parametrize("case,model_type,version", argvalues) + for jax in (True, False): + argvalues.extend( + (case, format, version, jax) + for case in test_numbers + or get_cases(format, version) + ) + metafunc.parametrize("case,model_type,version,jax", argvalues) diff --git a/tests/petab_test_suite/test_petab_suite.py b/tests/petab_test_suite/test_petab_suite.py index f5bf354cd3..5fe61adcf2 100755 --- a/tests/petab_test_suite/test_petab_suite.py +++ b/tests/petab_test_suite/test_petab_suite.py @@ -23,10 +23,10 @@ logger.addHandler(stream_handler) -def test_case(case, model_type, version): +def test_case(case, model_type, version, jax): """Wrapper for _test_case for handling test outcomes""" try: - _test_case(case, model_type, version) + _test_case(case, model_type, version, jax) except Exception as e: if isinstance( e, NotImplementedError @@ -41,10 +41,10 @@ def test_case(case, model_type, version): raise e -def _test_case(case, model_type, version): +def _test_case(case, model_type, version, jax): """Run a single PEtab test suite case""" case = petabtests.test_id_str(case) - logger.debug(f"Case {case} [{model_type}] [{version}]") + logger.debug(f"Case {case} [{model_type}] [{version}] [{jax}]") # load case_dir = petabtests.get_case_dir(case, model_type, version) @@ -57,34 +57,46 @@ def _test_case(case, model_type, version): model_name = ( f"petab_{model_type}_test_case_{case}" f"_{version.replace('.', '_')}" ) - model_output_dir = f"amici_models/{model_name}" + model_output_dir = f"amici_models/{model_name}" + ("_jax" if jax else "") model = import_petab_problem( petab_problem=problem, model_output_dir=model_output_dir, model_name=model_name, compile_=True, + jax=jax, ) - solver = model.getSolver() - solver.setSteadyStateToleranceFactor(1.0) - problem_parameters = dict( - zip(problem.x_free_ids, problem.x_nominal_free, strict=True) - ) + if jax: + from amici.jax import JAXProblem, run_simulations, petab_simulate + + jax_problem = JAXProblem(model, problem) + llh, ret = run_simulations(jax_problem) + chi2, _ = run_simulations(jax_problem, ret="chi2") + simulation_df = petab_simulate(jax_problem) + simulation_df.rename( + columns={petab.SIMULATION: petab.MEASUREMENT}, inplace=True + ) + else: + solver = model.getSolver() + solver.setSteadyStateToleranceFactor(1.0) + problem_parameters = dict( + zip(problem.x_free_ids, problem.x_nominal_free, strict=True) + ) - # simulate - ret = simulate_petab( - problem, - model, - problem_parameters=problem_parameters, - solver=solver, - log_level=logging.DEBUG, - ) + # simulate + ret = simulate_petab( + problem, + model, + problem_parameters=problem_parameters, + solver=solver, + log_level=logging.DEBUG, + ) - rdatas = ret["rdatas"] - chi2 = sum(rdata["chi2"] for rdata in rdatas) - llh = ret["llh"] - simulation_df = rdatas_to_measurement_df( - rdatas, model, problem.measurement_df - ) + rdatas = ret["rdatas"] + chi2 = sum(rdata["chi2"] for rdata in rdatas) + llh = ret["llh"] + simulation_df = rdatas_to_measurement_df( + rdatas, model, problem.measurement_df + ) petab.check_measurement_df(simulation_df, problem.observable_df) simulation_df = simulation_df.rename( columns={petab.MEASUREMENT: petab.SIMULATION} @@ -142,7 +154,10 @@ def _test_case(case, model_type, version): f"LLH: simulated: {llh}, expected: {gt_llh}, " f"match = {llhs_match}", ) - check_derivatives(problem, model, solver, problem_parameters) + if jax: + pass # skip derivative checks for now + else: + check_derivatives(problem, model, solver, problem_parameters) if not all([llhs_match, simulations_match]) or not chi2s_match: logger.error(f"Case {case} failed.") @@ -196,18 +211,19 @@ def run(): n_skipped = 0 n_total = 0 for version in ("v1.0.0", "v2.0.0"): - cases = petabtests.get_cases("sbml", version=version) - n_total += len(cases) - for case in cases: - try: - test_case(case, "sbml", version=version) - n_success += 1 - except Skipped: - n_skipped += 1 - except Exception as e: - # run all despite failures - logger.error(f"Case {case} failed.") - logger.error(e) + for jax in (False, True): + cases = petabtests.get_cases("sbml", version=version) + n_total += len(cases) + for case in cases: + try: + test_case(case, "sbml", version=version, jax=jax) + n_success += 1 + except Skipped: + n_skipped += 1 + except Exception as e: + # run all despite failures + logger.error(f"Case {case} failed.") + logger.error(e) logger.info(f"{n_success} / {n_total} successful, " f"{n_skipped} skipped") if n_success != len(cases):