diff --git a/python/sdist/amici/jax/model.py b/python/sdist/amici/jax/model.py
index 6c41b8b179..98e123b5f0 100644
--- a/python/sdist/amici/jax/model.py
+++ b/python/sdist/amici/jax/model.py
@@ -4,6 +4,7 @@
 
 from abc import abstractmethod
 from pathlib import Path
+import enum
 
 import diffrax
 import equinox as eqx
@@ -12,6 +13,20 @@
 import jaxtyping as jt
 
 
+class ReturnValue(enum.Enum):
+    llh = "log-likelihood"
+    nllhs = "pointwise negative log-likelihood"
+    x0 = "full initial state vector"
+    x0_solver = "reduced initial state vector"
+    x = "full state vector"
+    x_solver = "reduced state vector"
+    y = "observables"
+    sigmay = "standard deviations of the observables"
+    tcl = "total values for conservation laws"
+    res = "residuals"
+    chi2 = "sum(((observed - simulated) / sigma ) ** 2)"
+
+
 class JAXModel(eqx.Module):
     """
     JAXModel provides an abstract base class for a JAX-based implementation of an AMICI model. The class implements
@@ -440,7 +455,7 @@ def simulate_condition(
         x_preeq: jt.Float[jt.Array, "*nx"] = jnp.array([]),
         mask_reinit: jt.Bool[jt.Array, "*nx"] = jnp.array([]),
         x_reinit: jt.Float[jt.Array, "*nx"] = jnp.array([]),
-        ret: str = "llh",
+        ret: ReturnValue = ReturnValue.llh,
     ) -> tuple[jt.Float[jt.Array, "nt *nx"] | jnp.float_, dict]:
         r"""
         Simulate a condition.
@@ -478,18 +493,7 @@
         :param max_steps:
             maximum number of solver steps
         :param ret:
-            which output to return. Valid values are
-            - `llh`: log-likelihood (default)
-            - `nllhs`: negative log-likelihood at each time point
-            - `x0`: full initial state vector (after pre-equilibration)
-            - `x0_solver`: reduced initial state vector (after pre-equilibration)
-            - `x`: full state vector
-            - `x_solver`: reduced state vector
-            - `y`: observables
-            - `sigmay`: standard deviations of the observables
-            - `tcl`: total values for conservation laws (at final timepoint)
-            - `res`: residuals (observed - simulated)
-            - 'chi2': sum((observed - simulated) ** 2 / sigma ** 2)
+            which output to return. See :class:`ReturnValue` for available options.
         :return:
             output according to `ret` and statistics
         """
@@ -542,36 +546,54 @@ def simulate_condition(
 
         nllhs = self._nllhs(ts, x, p, tcl, my, iys)
         llh = -jnp.sum(nllhs)
-        obs_trafo = jax.vmap(
-            lambda y, iy_trafo: jnp.array(
-                [y, safe_log(y), safe_log(y) / jnp.log(10)]
-            )
-            .at[iy_trafo]
-            .get(),
-        )
-        ys_obj = obs_trafo(self._ys(ts, x, p, tcl, iys), iy_trafos)
-        m_obj = obs_trafo(my, iy_trafos)
-        return {
-            "llh": llh,
-            "nllhs": nllhs,
-            "x": self._x_rdatas(x, tcl),
-            "x_solver": x,
-            "y": self._ys(ts, x, p, tcl, iys),
-            "sigmay": self._sigmays(ts, x, p, tcl, iys),
-            "x0": self._x_rdata(x[0, :], tcl),
-            "x0_solver": x[0, :],
-            "tcl": tcl,
-            "res": self._ys(ts, x, p, tcl, iys) - my,
-            "chi2": jnp.sum(
-                jnp.square(ys_obj - m_obj)
-                / jnp.square(self._sigmays(ts, x, p, tcl, iys))
-            ),
-        }[ret], dict(
+
+        stats = dict(
             ts=ts,
             x=x,
+            llh=llh,
             stats_dyn=stats_dyn,
             stats_posteq=stats_posteq,
         )
+        if ret == ReturnValue.llh:
+            output = llh
+        elif ret == ReturnValue.nllhs:
+            output = nllhs
+        elif ret == ReturnValue.x:
+            output = self._x_rdatas(x, tcl)
+        elif ret == ReturnValue.x_solver:
+            output = x
+        elif ret == ReturnValue.y:
+            output = self._ys(ts, x, p, tcl, iys)
+        elif ret == ReturnValue.sigmay:
+            output = self._sigmays(ts, x, p, tcl, iys)
+        elif ret == ReturnValue.x0:
+            output = self._x_rdata(x[0, :], tcl)
+        elif ret == ReturnValue.x0_solver:
+            output = x[0, :]
+        elif ret == ReturnValue.tcl:
+            output = tcl
+        elif ret in (ReturnValue.res, ReturnValue.chi2):
+            obs_trafo = jax.vmap(
+                lambda y, iy_trafo: jnp.array(
+                    # needs to follow order in amici.jax.petab.SCALE_TO_INT
+                    [y, safe_log(y), safe_log(y) / jnp.log(10)]
+                )
+                .at[iy_trafo]
+                .get(),
+            )
+            ys_obj = obs_trafo(self._ys(ts, x, p, tcl, iys), iy_trafos)
+            m_obj = obs_trafo(my, iy_trafos)
+            if ret == ReturnValue.chi2:
+                output = jnp.sum(
+                    jnp.square(ys_obj - m_obj)
+                    / jnp.square(self._sigmays(ts, x, p, tcl, iys))
+                )
+            else:
+                output = ys_obj - m_obj
+        else:
+            raise NotImplementedError(f"Return value {ret} not implemented.")
+
+        return output, stats
 
     @eqx.filter_jit
     def preequilibrate_condition(
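Note on the dispatch rewrite in simulate_condition: the old dict-based return built every output (including the observable transforms only needed for res and chi2) before indexing with ret, whereas the if/elif chain evaluates just the requested quantity and raises NotImplementedError for anything unhandled. A minimal, self-contained sketch of the same selection pattern; Ret, select_output, and the toy objective below are illustrative stand-ins, not part of the AMICI API:

import enum

import jax.numpy as jnp


class Ret(enum.Enum):
    # stand-in for amici.jax.model.ReturnValue
    llh = "log-likelihood"
    res = "residuals"


def select_output(ret: Ret, y: jnp.ndarray, my: jnp.ndarray):
    # only the requested branch is evaluated, mirroring the new dispatch
    if ret == Ret.llh:
        return -jnp.sum(0.5 * jnp.square(y - my))  # toy Gaussian objective
    if ret == Ret.res:
        return y - my
    raise NotImplementedError(f"Return value {ret} not implemented.")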
diff --git a/python/sdist/amici/jax/petab.py b/python/sdist/amici/jax/petab.py
index 850b11a02b..c1f3da5c2e 100644
--- a/python/sdist/amici/jax/petab.py
+++ b/python/sdist/amici/jax/petab.py
@@ -4,6 +4,7 @@
 
 from collections.abc import Iterable
 from pathlib import Path
+
 import diffrax
 import equinox as eqx
 import jaxtyping as jt
@@ -18,7 +19,7 @@
     ParameterMappingForCondition,
     create_parameter_mapping,
 )
-from amici.jax.model import JAXModel
+from amici.jax.model import JAXModel, ReturnValue
 
 DEFAULT_CONTROLLER_SETTINGS = {
     "atol": 1e-8,
@@ -28,6 +29,12 @@
     "dcoeff": 0.0,
 }
 
+SCALE_TO_INT = {
+    petab.LIN: 0,
+    petab.LOG: 1,
+    petab.LOG10: 2,
+}
+
 
 def jax_unscale(
     parameter: jnp.float_,
@@ -171,16 +178,19 @@ def _get_parameter_mappings(
 
     def _get_measurements(
         self, simulation_conditions: pd.DataFrame
-    ) -> dict[
-        tuple[str, ...],
-        tuple[
-            np.ndarray,
-            np.ndarray,
-            np.ndarray,
-            np.ndarray,
-            np.ndarray,
-            np.ndarray,
+    ) -> tuple[
+        dict[
+            tuple[str, ...],
+            tuple[
+                np.ndarray,
+                np.ndarray,
+                np.ndarray,
+                np.ndarray,
+                np.ndarray,
+                np.ndarray,
+            ],
         ],
+        dict[tuple[str, ...], tuple[int, ...]],
     ]:
         """
         Get measurements for the model based on the provided simulation conditions.
@@ -221,14 +231,9 @@ def _get_measurements(
                 petab.OBSERVABLE_TRANSFORMATION
                 in self._petab_problem.observable_df
             ):
-                trafo_map = {
-                    petab.LIN: 0,
-                    petab.LOG: 1,
-                    petab.LOG10: 2,
-                }
                 iy_trafos = np.array(
                     [
-                        trafo_map[
+                        SCALE_TO_INT[
                             self._petab_problem.observable_df.loc[
                                 oid, petab.OBSERVABLE_TRANSFORMATION
                             ]
@@ -345,6 +350,75 @@ def load_parameters(
         )
         return self._unscale(p, pscale)
 
+    def _state_needs_reinitialisation(
+        self,
+        simulation_condition: str,
+        state_id: str,
+    ) -> bool:
+        """
+        Check if a state needs reinitialisation for a simulation condition.
+
+        :param simulation_condition:
+            simulation condition to check reinitialisation for
+        :param state_id:
+            state id to check reinitialisation for
+        :return:
+            True if state needs reinitialisation, False otherwise
+        """
+        if state_id not in self._petab_problem.condition_df:
+            return False
+        xval = self._petab_problem.condition_df.loc[
+            simulation_condition, state_id
+        ]
+        if isinstance(xval, Number) and np.isnan(xval):
+            return False
+        return True
+
+    def _state_reinitialisation_value(
+        self,
+        simulation_condition: str,
+        state_id: str,
+        p: jt.Float[jt.Array, "np"],
+    ) -> jt.Float[jt.Scalar, ""] | float:  # noqa: F722
+        """
+        Get the reinitialisation value for a state.
+
+        :param simulation_condition:
+            simulation condition to get reinitialisation value for
+        :param state_id:
+            state id to get reinitialisation value for
+        :param p:
+            parameters for the simulation condition
+        :return:
+            reinitialisation value for the state
+        """
+        if state_id not in self._petab_problem.condition_df:
+            # no reinitialisation, return dummy value
+            return 0.0
+        xval = self._petab_problem.condition_df.loc[
+            simulation_condition, state_id
+        ]
+        if isinstance(xval, Number) and np.isnan(xval):
+            # no reinitialisation, return dummy value
+            return 0.0
+        if isinstance(xval, Number):
+            # numerical value, return as is
+            return xval
+        if xval in self.model.parameter_ids:
+            # model parameter, return value
+            return p[self.model.parameter_ids.index(xval)]
+        if xval in self.parameter_ids:
+            # estimated PEtab parameter, return unscaled value
+            return jax_unscale(
+                self.get_petab_parameter_by_id(xval),
+                self._petab_problem.parameter_df.loc[
+                    xval, petab.PARAMETER_SCALE
+                ],
+            )
+        # only remaining option is nominal value for PEtab parameter
+        # that is not estimated, return nominal value
+        return self._petab_problem.parameter_df.loc[xval, petab.NOMINAL_VALUE]
+
     def load_reinitialisation(
         self,
         simulation_condition: str,
@@ -360,60 +434,24 @@ def load_reinitialisation(
         :return:
             Tuple of reinitialisation mask and value for states.
""" - mask = jax.lax.stop_gradient( - jnp.array( - [ - xname in self._petab_problem.condition_df - and not ( - isinstance( - self._petab_problem.condition_df.loc[ - simulation_condition, xname - ], - Number, - ) - and np.isnan( - self._petab_problem.condition_df.loc[ - simulation_condition, xname - ] - ) - ) - for xname in self.model.state_ids - ] - ) + if not any( + x_id in self._petab_problem.condition_df + for x_id in self.model.state_ids + ): + return jnp.array([]), jnp.array([]) + + mask = jnp.array( + [ + self._state_needs_reinitialisation(simulation_condition, x_id) + for x_id in self.model.state_ids + ] ) reinit_x = jnp.array( [ - 0 - if xname not in self._petab_problem.condition_df - or ( - isinstance( - xval := self._petab_problem.condition_df.loc[ - simulation_condition, xname - ], - Number, - ) - and np.isnan(xval) - ) - else xval - if isinstance( - xval := self._petab_problem.condition_df.loc[ - simulation_condition, xname - ], - Number, - ) - else p[self.model.parameter_ids.index(xval)] - if xval in self.model.parameter_ids - else jax_unscale( - self.get_petab_parameter_by_id(xval), - self._petab_problem.parameter_df.loc[ - xval, petab.PARAMETER_SCALE - ], + self._state_reinitialisation_value( + simulation_condition, x_id, p ) - if xval in self.parameter_ids - else self._petab_problem.parameter_df.loc[ - xval, petab.NOMINAL_VALUE - ] - for xname in self.model.state_ids + for x_id in self.model.state_ids ] ) return mask, reinit_x @@ -434,7 +472,7 @@ def run_simulation( controller: diffrax.AbstractStepSizeController, max_steps: jnp.int_, x_preeq: jt.Float[jt.Array, "*nx"] = jnp.array([]), # noqa: F821, F722 - ret: str = "llh", + ret: ReturnValue = ReturnValue.llh, ) -> tuple[jnp.float_, dict]: """ Run a simulation for a given simulation condition. @@ -450,18 +488,7 @@ def run_simulation( :param x_preeq: Pre-equilibration state if available :param ret: - which output to return. Valid values are - - `llh`: log-likelihood (default) - - `nllhs`: negative log-likelihood at each time point - - `x0`: full initial state vector (after pre-equilibration) - - `x0_solver`: reduced initial state vector (after pre-equilibration) - - `x`: full state vector - - `x_solver`: reduced state vector - - `y`: observables - - `sigmay`: standard deviations of the observables - - `tcl`: total values for conservation laws (at final timepoint) - - `res`: residuals (observed - simulated) - - 'chi2': sum((observed - simulated) ** 2 / sigma ** 2) + which output to return. See :class:`ReturnValue` for available options. 
diff --git a/tests/benchmark-models/test_petab_benchmark.py b/tests/benchmark-models/test_petab_benchmark.py
index 6a388f7493..7c70015a8c 100644
--- a/tests/benchmark-models/test_petab_benchmark.py
+++ b/tests/benchmark-models/test_petab_benchmark.py
@@ -340,12 +340,13 @@ def test_jax_llh(benchmark_problem):
             [problem_parameters[pid] for pid in jax_problem.parameter_ids]
         ),
     )
+    llh_jax, _ = beartype(run_simulations)(jax_problem)
     if problem_id in problems_for_gradient_check:
-        (llh_jax, _), sllh_jax = eqx.filter_jit(
-            eqx.filter_value_and_grad(run_simulations, has_aux=True)
+        (llh_jax, _), sllh_jax = eqx.filter_value_and_grad(
+            run_simulations, has_aux=True
         )(jax_problem)
     else:
-        llh_jax, _ = beartype(eqx.filter_jit(run_simulations))(jax_problem)
+        llh_jax, _ = beartype(run_simulations)(jax_problem)
 
     np.testing.assert_allclose(
         llh_jax,
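With these changes, callers select outputs via ReturnValue members instead of strings; llh and chi2 keep the memory-efficient RecursiveCheckpointAdjoint, while all other return values fall back to DirectAdjoint. A hedged usage sketch, assuming run_simulations is importable from amici.jax.petab (as in the test above) and jax_problem is an already constructed JAXProblem:

from amici.jax.model import ReturnValue
from amici.jax.petab import run_simulations  # assumed import path

# jax_problem: an existing amici.jax.petab.JAXProblem instance
llh, results = run_simulations(jax_problem)  # default: ReturnValue.llh
ys, _ = run_simulations(jax_problem, ret=ReturnValue.y)  # observables per condition
chi2, _ = run_simulations(jax_problem, ret=ReturnValue.chi2)  # scalar objective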