diff --git a/python/sdist/amici/jax.py b/python/sdist/amici/jax.py index fc16a533e1..6161759ebd 100644 --- a/python/sdist/amici/jax.py +++ b/python/sdist/amici/jax.py @@ -8,6 +8,7 @@ import jax.numpy as jnp import numpy as np import jax +import pandas as pd import petab.v1 as petab import amici @@ -24,66 +25,34 @@ class JAXModel(eqx.Module): - _unscale_funs = { - amici.ParameterScaling.none: lambda x: x, - amici.ParameterScaling.ln: lambda x: jnp.exp(x), - amici.ParameterScaling.log10: lambda x: jnp.power(10, x), - } solver: diffrax.AbstractSolver controller: diffrax.AbstractStepSizeController - atol: float - rtol: float - pcoeff: float - icoeff: float - dcoeff: float maxsteps: int parameters: jnp.ndarray - parameter_mappings: dict[tuple[str], ParameterMappingForCondition] - term: diffrax.ODETerm + parameter_mappings: dict[tuple[str], ParameterMappingForCondition] | None + measurements: dict[tuple[str], pd.DataFrame] | None petab_problem: petab.Problem | None def __init__(self): self.solver = diffrax.Kvaerno5() - self.atol: float = 1e-8 - self.rtol: float = 1e-8 - self.pcoeff: float = 0.4 - self.icoeff: float = 0.3 - self.dcoeff: float = 0.0 self.maxsteps: int = 2**14 self.controller = diffrax.PIDController( - rtol=self.rtol, - atol=self.atol, - pcoeff=self.pcoeff, - icoeff=self.icoeff, - dcoeff=self.dcoeff, + rtol=1e-8, + atol=1e-8, + pcoeff=0.4, + icoeff=0.3, + dcoeff=0.0, ) - self.term = diffrax.ODETerm(self.xdot) self.petab_problem = None self.parameter_mappings = None + self.measurements = None self.parameters = jnp.array([]) - def set_petab_problem(self, petab_problem: petab.Problem) -> "JAXModel": - """ - Set the PEtab problem for the model and updates parameters to the nominal values. - :param petab_problem: - Petab problem to set. - :return: JAXModel instance - """ - - is_leaf = lambda x: x is None if self.petab_problem is None else None # noqa: E731 - model = eqx.tree_at( - lambda x: x.petab_problem, - self, - petab_problem, - is_leaf=is_leaf, - ) - - simulation_conditions = ( - petab_problem.get_simulation_conditions_from_measurement_df() - ) - + def _set_parameter_mappings( + self, simulation_conditions: pd.DataFrame + ) -> "JAXModel": mappings = create_parameter_mapping( - petab_problem=petab_problem, + petab_problem=self.petab_problem, simulation_conditions=simulation_conditions, scaled_parameters=False, amici_model=self, @@ -95,31 +64,81 @@ def set_petab_problem(self, petab_problem: petab.Problem) -> "JAXModel": simulation_conditions.iterrows(), mappings ) } + is_leaf = ( # noqa: E731 lambda x: x is None if self.parameter_mappings is None else None ) - model = eqx.tree_at( + return eqx.tree_at( lambda x: x.parameter_mappings, - model, + self, parameter_mappings, is_leaf=is_leaf, ) + def _set_measurements( + self, simulation_conditions: pd.DataFrame + ) -> "JAXModel": + measurements = dict() + for _, simulation_condition in simulation_conditions.iterrows(): + measurements_df = self.petab_problem.measurement_df + for k, v in simulation_condition.items(): + measurements_df = measurements_df.query(f"{k} == '{v}'") + + ts = _get_timepoints_with_replicates(measurements_df) + my = _get_measurements_and_sigmas( + measurements_df, ts, self.observable_ids + )[0].flatten() + measurements[tuple(simulation_condition)] = np.array(ts), my + is_leaf = ( # noqa: E731 + lambda x: x is None if self.measurements is None else None + ) + return eqx.tree_at( + lambda x: x.measurements, + self, + measurements, + is_leaf=is_leaf, + ) + + def _set_nominal_parameter_values(self) -> "JAXModel": nominal_values = jnp.array( [ petab.scale( - model.petab_problem.parameter_df.loc[ + self.petab_problem.parameter_df.loc[ pval, petab.NOMINAL_VALUE ], - model.petab_problem.parameter_df.loc[ + self.petab_problem.parameter_df.loc[ pval, petab.PARAMETER_SCALE ], ) - for pval in model.petab_parameter_ids() + for pval in self.petab_parameter_ids() ] ) + return eqx.tree_at(lambda x: x.parameters, self, nominal_values) - return eqx.tree_at(lambda x: x.parameters, model, nominal_values) + def _set_petab_problem(self, petab_problem: petab.Problem) -> "JAXModel": + is_leaf = lambda x: x is None if self.petab_problem is None else None # noqa: E731 + return eqx.tree_at( + lambda x: x.petab_problem, + self, + petab_problem, + is_leaf=is_leaf, + ) + + def set_petab_problem(self, petab_problem: petab.Problem) -> "JAXModel": + """ + Set the PEtab problem for the model and updates parameters to the nominal values. + :param petab_problem: + Petab problem to set. + :return: JAXModel instance + """ + + model = self._set_petab_problem(petab_problem) + simulation_conditions = ( + petab_problem.get_simulation_conditions_from_measurement_df() + ) + model = model._set_parameter_mappings(simulation_conditions) + model = model._set_measurements(simulation_conditions) + return model._set_nominal_parameter_values() @staticmethod @abstractmethod @@ -216,7 +235,7 @@ def _posteq(self, p, k, x, tcl): def _eq(self, p, k, tcl, x0): sol = diffrax.diffeqsolve( - self.term, + diffrax.ODETerm(self.xdot), self.solver, args=(p, k, tcl), t0=0.0, @@ -232,7 +251,7 @@ def _eq(self, p, k, tcl, x0): def _solve(self, ts, p, k, x0, checkpointed): tcl = self.tcl(x0, p, k) sol = diffrax.diffeqsolve( - self.term, + diffrax.ODETerm(self.xdot), self.solver, args=(p, k, tcl), t0=0.0, @@ -264,15 +283,15 @@ def _loss(self, obs: jnp.ndarray, sigmay: jnp.ndarray, my: np.ndarray): loss_fun = jax.vmap(self.Jy, in_axes=(0, 0, 0)) return -jnp.sum(loss_fun(obs, my, sigmay)) - def _run( + def run_condition( self, - ts: np.ndarray, - ts_dyn: np.ndarray, + ts: jnp.ndarray, + ts_dyn: jnp.ndarray, p: jnp.ndarray, - k: np.ndarray, - k_preeq: np.ndarray, + k: jnp.ndarray, + k_preeq: jnp.ndarray, my: jnp.ndarray, - pscale: np.ndarray, + pscale: jnp.ndarray, checkpointed=True, dynamic="true", ): @@ -323,55 +342,55 @@ def _run( return llh, (x_rdata, obs, stats) @eqx.filter_jit - def run( + def _fun( self, - ts: np.ndarray, - ts_dyn: np.ndarray, + ts: jnp.ndarray, + ts_dyn: jnp.ndarray, p: jnp.ndarray, - k: np.ndarray, - k_preeq: np.ndarray, - my: np.ndarray, - pscale: np.ndarray, + k: jnp.ndarray, + k_preeq: jnp.ndarray, + my: jnp.ndarray, + pscale: jnp.ndarray, dynamic="true", ): - return self._run( + return self.run_condition( ts, ts_dyn, p, k, k_preeq, my, pscale, dynamic=dynamic ) @eqx.filter_jit - def srun( + def _grad( self, - ts: np.ndarray, - ts_dyn: np.ndarray, + ts: jnp.ndarray, + ts_dyn: jnp.ndarray, p: jnp.ndarray, - k: np.ndarray, - k_preeq: np.ndarray, - my: np.ndarray, - pscale: np.ndarray, + k: jnp.ndarray, + k_preeq: jnp.ndarray, + my: jnp.ndarray, + pscale: jnp.ndarray, dynamic="true", ): (llh, (x, obs, stats)), sllh = ( - jax.value_and_grad(self._run, 2, True) + jax.value_and_grad(self.run_condition, 2, True) )(ts, ts_dyn, p, k, k_preeq, my, pscale, dynamic=dynamic) return llh, sllh, (x, obs, stats) @eqx.filter_jit - def s2run( + def _hessian( self, - ts: np.ndarray, - ts_dyn: np.ndarray, + ts: jnp.ndarray, + ts_dyn: jnp.ndarray, p: jnp.ndarray, - k: np.ndarray, - k_preeq: np.ndarray, - my: np.ndarray, - pscale: np.ndarray, + k: jnp.ndarray, + k_preeq: jnp.ndarray, + my: jnp.ndarray, + pscale: jnp.ndarray, dynamic="true", ): (llh, (x, obs, stats)), sllh = ( - jax.value_and_grad(self._run, 2, True) + jax.value_and_grad(self.run_condition, 2, True) )(ts, ts_dyn, p, k, k_preeq, my, pscale, dynamic=dynamic) - s2llh = jax.hessian(self._run, 2, True)( + s2llh = jax.hessian(self.run_condition, 2, True)( ts, ts_dyn, p, @@ -391,16 +410,7 @@ def run_simulation( sensitivity_order: amici.SensitivityOrder = amici.SensitivityOrder.none, ): parameter_mapping = self.parameter_mappings[simulation_condition] - measurements_df = self.petab_problem.measurement_df - for v, k in zip( - simulation_condition, - ( - petab.SIMULATION_CONDITION_ID, - petab.PREEQUILIBRATION_CONDITION_ID, - ), - ): - measurements_df = measurements_df.query(f"{k} == '{v}'") - ts = _get_timepoints_with_replicates(measurements_df) + ts, my = self.measurements[simulation_condition] p = jnp.array( [ pval @@ -430,10 +440,7 @@ def run_simulation( if k in parameter_mapping.map_preeq_fix ] ) - my = _get_measurements_and_sigmas( - measurements_df, ts, self.observable_ids - )[0].flatten() - ts = np.array(ts) + ts_dyn = ts[np.isfinite(ts)] dynamic = "true" if len(ts_dyn) and np.max(ts_dyn) > 0 else "false" @@ -445,7 +452,7 @@ def run_simulation( ( rdata_kwargs["llh"], (rdata_kwargs["x"], rdata_kwargs["y"], rdata_kwargs["stats"]), - ) = self.run( + ) = self._fun( ts, ts_dyn, p, k_sim, k_preeq, my, pscale, dynamic=dynamic ) elif sensitivity_order == amici.SensitivityOrder.first: @@ -453,7 +460,7 @@ def run_simulation( rdata_kwargs["llh"], rdata_kwargs["sllh"], (rdata_kwargs["x"], rdata_kwargs["y"], rdata_kwargs["stats"]), - ) = self.srun( + ) = self._grad( ts, ts_dyn, p, k_sim, k_preeq, my, pscale, dynamic=dynamic ) elif sensitivity_order == amici.SensitivityOrder.second: @@ -462,7 +469,7 @@ def run_simulation( rdata_kwargs["sllh"], rdata_kwargs["s2llh"], (rdata_kwargs["x"], rdata_kwargs["y"], rdata_kwargs["stats"]), - ) = self.s2run( + ) = self._hessian( ts, ts_dyn, p, k_sim, k_preeq, my, pscale, dynamic=dynamic ) diff --git a/python/tests/test_jax.py b/python/tests/test_jax.py index 5898262f90..8c78253334 100644 --- a/python/tests/test_jax.py +++ b/python/tests/test_jax.py @@ -151,20 +151,20 @@ def check_fields_jax( ( r_jax["llh"], (r_jax["x"], r_jax["y"], r_jax["stats"]), - ) = jax_model.run(**kwargs) + ) = jax_model._fun(**kwargs) elif sensi_order == amici.SensitivityOrder.first: ( r_jax["llh"], r_jax["sllh"], (r_jax["x"], r_jax["y"], r_jax["stats"]), - ) = jax_model.srun(**kwargs) + ) = jax_model._grad(**kwargs) elif sensi_order == amici.SensitivityOrder.second: ( r_jax["llh"], r_jax["sllh"], r_jax["s2llh"], (r_jax["x"], r_jax["y"], r_jax["stats"]), - ) = jax_model.s2run(**kwargs) + ) = jax_model._hessian(**kwargs) for field in fields: for r_amici, r_jax in zip(rs_amici, [r_jax]):