diff --git a/python/sdist/amici/de_export.py b/python/sdist/amici/de_export.py
index 793d746e9a..823f5f8ca1 100644
--- a/python/sdist/amici/de_export.py
+++ b/python/sdist/amici/de_export.py
@@ -319,7 +319,7 @@ def jnp_array_str(array) -> str:
                     ),
                     indent,
                 )
-            )
+            )[indent:]
             for eq_name in eq_names
         },
         **{
diff --git a/python/sdist/amici/jax.template.py b/python/sdist/amici/jax.template.py
index a53ab2066a..08a546826f 100644
--- a/python/sdist/amici/jax.template.py
+++ b/python/sdist/amici/jax.template.py
@@ -8,17 +8,16 @@ class JAXModel_TPL_MODEL_NAME(JAXModel):
     def __init__(self):
         super().__init__()

-    @staticmethod
-    def xdot(t, x, args):
+    def xdot(self, t, x, args):
         pk, tcl = args

         TPL_X_SYMS = x
         TPL_PK_SYMS = pk
         TPL_TCL_SYMS = tcl
-        TPL_W_SYMS = JAXModel_TPL_MODEL_NAME._w(t, x, pk, tcl)
+        TPL_W_SYMS = self._w(t, x, pk, tcl)

-TPL_XDOT_EQ
+        TPL_XDOT_EQ

         return TPL_XDOT_RET

@@ -29,7 +28,7 @@ def _w(t, x, pk, tcl):
         TPL_PK_SYMS = pk
         TPL_TCL_SYMS = tcl

-TPL_W_EQ
+        TPL_W_EQ

         return TPL_W_RET

@@ -38,7 +37,7 @@ def x0(pk):

         TPL_PK_SYMS = pk

-TPL_X0_EQ
+        TPL_X0_EQ

         return TPL_X0_RET

@@ -47,7 +46,7 @@ def x_solver(x):

         TPL_X_RDATA_SYMS = x

-TPL_X_SOLVER_EQ
+        TPL_X_SOLVER_EQ

         return TPL_X_SOLVER_RET

@@ -57,7 +56,7 @@ def x_rdata(x, tcl):
         TPL_X_SYMS = x
         TPL_TCL_SYMS = tcl

-TPL_X_RDATA_EQ
+        TPL_X_RDATA_EQ

         return TPL_X_RDATA_RET

@@ -67,7 +66,7 @@ def tcl(x, pk):
         TPL_X_RDATA_SYMS = x
         TPL_PK_SYMS = pk

-TPL_TOTAL_CL_EQ
+        TPL_TOTAL_CL_EQ

         return TPL_TOTAL_CL_RET

@@ -77,7 +76,7 @@ def y(self, t, x, pk, tcl):
         TPL_PK_SYMS = pk
         TPL_W_SYMS = self._w(t, x, pk, tcl)

-TPL_Y_EQ
+        TPL_Y_EQ

         return TPL_Y_RET

@@ -86,7 +85,7 @@ def sigmay(self, y, pk):

         TPL_Y_SYMS = y

-TPL_SIGMAY_EQ
+        TPL_SIGMAY_EQ

         return TPL_SIGMAY_RET

@@ -94,16 +93,11 @@ def sigmay(self, y, pk):
     def llh(self, t, x, pk, tcl, my, iy):
         y = self.y(t, x, pk, tcl)
         TPL_Y_SYMS = y
-        sigmay = self.sigmay(y, pk)
-        TPL_SIGMAY_SYMS = sigmay
+        TPL_SIGMAY_SYMS = self.sigmay(y, pk)

-TPL_JY_EQ
+        TPL_JY_EQ

-        return jnp.array([
-            TPL_JY_RET.at[iy].get(),
-            y.at[iy].get(),
-            sigmay.at[iy].get()
-        ])
+        return TPL_JY_RET.at[iy].get()

     @property
     def observable_ids(self):
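Note on the template change above: `xdot` and the templated equation placeholders are now emitted as instance methods indented into the method body, so generated equations reach helpers such as `_w` through `self` instead of hard-coding the generated class name. A minimal sketch of why this works (toy module, not AMICI-generated code; all names here are illustrative):

```python
import diffrax
import equinox as eqx
import jax.numpy as jnp


class ToyModel(eqx.Module):
    """Toy stand-in for a generated JAXModel subclass."""

    k: float

    def _w(self, t, x):
        # stand-in for the generated expression helper ``_w``
        return self.k * x

    def xdot(self, t, x, args):
        # bound method with the (t, y, args) signature diffrax expects
        return -self._w(t, x)


model = ToyModel(k=0.5)
sol = diffrax.diffeqsolve(
    diffrax.ODETerm(model.xdot),  # a bound method is a valid vector field
    diffrax.Tsit5(),
    t0=0.0,
    t1=1.0,
    dt0=0.1,
    y0=jnp.array([1.0]),
)
```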
diff --git a/python/sdist/amici/jax/model.py b/python/sdist/amici/jax/model.py
index ffd58ee8a1..f412faecac 100644
--- a/python/sdist/amici/jax/model.py
+++ b/python/sdist/amici/jax/model.py
@@ -3,7 +3,6 @@
 import diffrax
 import equinox as eqx
 import jax.numpy as jnp
-import numpy as np
 import jax

 # always use 64-bit precision. No-brainer on CPUs and GPUs don't make sense for stiff systems.
@@ -16,10 +15,12 @@ class JAXModel(eqx.Module):
     JAXModel must provide model specific implementations of abstract methods.
     """

-    @staticmethod
     @abstractmethod
     def xdot(
-        t: jnp.float_, x: jnp.ndarray, args: tuple[jnp.ndarray, jnp.ndarray]
+        self,
+        t: jnp.float_,
+        x: jnp.ndarray,
+        args: tuple[jnp.ndarray, jnp.ndarray],
     ) -> jnp.ndarray:
         """
         Right-hand side of the ODE system.
@@ -190,21 +191,6 @@ def parameter_ids(self) -> list[str]:
         """
         ...

-    def _preeq(self, p, solver, controller, max_steps):
-        """
-        Pre-equilibration of the model.
-
-        :param p:
-            parameters
-        :return:
-            Initial state vector
-        """
-        x0 = self.x_solver(self.x0(p))
-        tcl = self.tcl(x0, p)
-        return self._eq(p, tcl, x0, solver, controller, max_steps)
-
-    def _posteq(self, p, x, tcl, solver, controller, max_steps):
-        return self._eq(p, tcl, x, solver, controller, max_steps)
-
     def _eq(self, p, tcl, x0, solver, controller, max_steps):
         sol = diffrax.diffeqsolve(
             diffrax.ODETerm(self.xdot),
@@ -216,12 +202,12 @@ def _eq(self, p, tcl, x0, solver, controller, max_steps):
             y0=x0,
             stepsize_controller=controller,
             max_steps=max_steps,
+            adjoint=diffrax.DirectAdjoint(),
             event=diffrax.Event(cond_fn=diffrax.steady_state_event()),
         )
-        return sol.ys[-1, :]
+        return sol.ys[-1, :], sol.stats

-    def _solve(self, ts, p, x0, solver, controller, max_steps):
-        tcl = self.tcl(x0, p)
+    def _solve(self, p, ts, tcl, x0, solver, controller, max_steps, adjoint):
         sol = diffrax.diffeqsolve(
             diffrax.ODETerm(self.xdot),
             solver,
@@ -229,14 +215,14 @@ def _solve(self, ts, p, x0, solver, controller, max_steps):
             t0=0.0,
             t1=ts[-1],
             dt0=None,
-            y0=self.x_solver(x0),
+            y0=x0,
             stepsize_controller=controller,
             max_steps=max_steps,
-            adjoint=diffrax.RecursiveCheckpointAdjoint(),
+            adjoint=adjoint,
             saveat=diffrax.SaveAt(ts=ts),
             throw=False,
         )
-        return sol.ys, tcl, sol.stats
+        return sol.ys, sol.stats

     def _x_rdata(self, x, tcl):
         return jax.vmap(self.x_rdata, in_axes=(0, None))(x, tcl)
@@ -246,62 +232,105 @@ def _outputs(self, ts, x, p, tcl, my, iys) -> jnp.float_:
             ts, x, p, tcl, my, iys
         )

+    def _y(self, ts, xs, p, tcl, iys):
+        return jax.vmap(
+            lambda t, x, p, tcl, iy: self.y(t, x, p, tcl).at[iy].get(),
+            in_axes=(0, 0, None, None, 0),
+        )(ts, xs, p, tcl, iys)
+
+    def _sigmay(self, ts, xs, p, tcl, iys):
+        return jax.vmap(
+            lambda t, x, p, tcl, iy: self.sigmay(self.y(t, x, p, tcl), p)
+            .at[iy]
+            .get(),
+            in_axes=(0, 0, None, None, 0),
+        )(ts, xs, p, tcl, iys)
+
     # @eqx.filter_jit
     def simulate_condition(
         self,
-        ts: np.ndarray,
-        ts_dyn: np.ndarray,
-        my: np.ndarray,
-        iys: np.ndarray,
         p: jnp.ndarray,
         p_preeq: jnp.ndarray,
-        dynamic: bool,
+        ts_preeq: jnp.ndarray,
+        ts_dyn: jnp.ndarray,
+        ts_posteq: jnp.ndarray,
+        my: jnp.ndarray,
+        iys: jnp.ndarray,
         solver: diffrax.AbstractSolver,
         controller: diffrax.AbstractStepSizeController,
+        adjoint: diffrax.AbstractAdjoint,
         max_steps: int,
+        ret: str = "llh",
     ):
         # Pre-equilibration
         if p_preeq.shape[0] > 0:
-            x0 = self._preeq(p_preeq, solver, controller, max_steps)
+            x0 = self.x0(p_preeq)
+            tcl = self.tcl(x0, p_preeq)
+            current_x = self.x_solver(x0)
+            current_x, stats_preeq = self._eq(
+                p_preeq, tcl, current_x, solver, controller, max_steps
+            )
+            # update tcl with new parameters
+            tcl = self.tcl(self.x_rdata(current_x, tcl), p)
         else:
             x0 = self.x0(p)
+            current_x = self.x_solver(x0)
+            stats_preeq = None
+
+            tcl = self.tcl(x0, p)
+        x_preq = jnp.repeat(
+            current_x.reshape(1, -1), ts_preeq.shape[0], axis=0
+        )

         # Dynamic simulation
-        if dynamic:
-            x, tcl, stats = self._solve(
-                ts_dyn, p, x0, solver, controller, max_steps
+        if ts_dyn.shape[0] > 0:
+            x_dyn, stats_dyn = self._solve(
+                p,
+                ts_dyn,
+                tcl,
+                current_x,
+                solver,
+                controller,
+                max_steps,
+                adjoint,
             )
+            current_x = x_dyn[-1, :]
         else:
-            x = jnp.repeat(
-                self.x_solver(x0).reshape(1, -1),
-                len(ts_dyn),
-                axis=0,
+            x_dyn = jnp.repeat(
+                current_x.reshape(1, -1), ts_dyn.shape[0], axis=0
             )
-            tcl = self.tcl(x0, p)
-            stats = None
+            stats_dyn = None

         # Post-equilibration
-        if len(ts) > len(ts_dyn):
-            if len(ts_dyn) > 0:
-                x_final = x[-1, :]
-            else:
-                x_final = self.x_solver(x0)
-            x_posteq = self._posteq(
-                p, x_final, tcl, solver, controller, max_steps
-            )
-            x_posteq = jnp.repeat(
-                x_posteq.reshape(1, -1),
-                len(ts) - len(ts_dyn),
-                axis=0,
+        if ts_posteq.shape[0] > 0:
+            current_x, stats_posteq = self._eq(
+                p, tcl, current_x, solver, controller, max_steps
             )
-            if len(ts_dyn) > 0:
-                x = jnp.concatenate((x, x_posteq), axis=0)
-            else:
-                x = x_posteq
-
-        outputs = self._outputs(ts, x, p, tcl, my, iys)
-        llh = -jnp.sum(outputs[:, 0])
-        obs = outputs[:, 1]
-        sigmay = outputs[:, 2]
-        x_rdata = jnp.stack(self._x_rdata(x, tcl), axis=1)
-        return llh, dict(llh=llh, x=x_rdata, y=obs, sigmay=sigmay, stats=stats)
+        else:
+            stats_posteq = None
+
+        x_posteq = jnp.repeat(
+            current_x.reshape(1, -1), ts_posteq.shape[0], axis=0
+        )
+
+        ts = jnp.concatenate((ts_preeq, ts_dyn, ts_posteq), axis=0)
+        x = jnp.concatenate((x_preq, x_dyn, x_posteq), axis=0)
+
+        llhs = self._outputs(ts, x, p, tcl, my, iys)
+        llh = -jnp.sum(llhs)
+        return {
+            "llh": llh,
+            "llhs": llhs,
+            "x": self._x_rdata(x, tcl),
+            "x_solver": x,
+            "y": self._y(ts, x, p, tcl, iys),
+            "sigmay": self._sigmay(ts, x, p, tcl, iys),
+            "x0": self.x_rdata(x_preq[-1, :], tcl),
+            "x0_solver": x_preq[-1, :],
+            "tcl": tcl,
+            "res": self._y(ts, x, p, tcl, iys) - my,
+        }[ret], dict(
+            stats_preeq=stats_preeq,
+            stats_dyn=stats_dyn,
+            stats_posteq=stats_posteq,
+        )
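The reworked `simulate_condition` replaces the `dynamic` flag with explicit `ts_preeq`/`ts_dyn`/`ts_posteq` timepoint vectors and returns a `(value, stats)` pair, where `ret` selects which quantity is returned and the second element aggregates per-phase solver statistics. A hedged usage sketch — `model` and `p` are assumed inputs (a generated JAXModel instance and a matching parameter vector), not defined by this diff:

```python
import diffrax
import jax.numpy as jnp

my = jnp.array([0.1, 0.2])  # flattened measurements
iys = jnp.array([0, 0])     # observable index of each measurement

llh, stats = model.simulate_condition(
    p,
    jnp.array([]),          # p_preeq: empty array -> skip pre-equilibration
    jnp.array([]),          # ts_preeq: measurements taken at t == 0
    jnp.array([1.0, 2.0]),  # ts_dyn: finite, positive timepoints
    jnp.array([]),          # ts_posteq: measurements at t = inf
    my,
    iys,
    diffrax.Kvaerno5(),
    diffrax.PIDController(atol=1e-8, rtol=1e-8),
    diffrax.RecursiveCheckpointAdjoint(),
    2**10,                  # max_steps
    ret="llh",              # or "x", "y", "sigmay", "res", ...
)
```

Because the dict of candidate outputs is indexed by `ret` before returning, the same entry point can be differentiated with respect to any of the outputs, which is what the updated tests below rely on.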
diff --git a/python/sdist/amici/jax/petab.py b/python/sdist/amici/jax/petab.py
index 6bf090d114..deb1d12d92 100644
--- a/python/sdist/amici/jax/petab.py
+++ b/python/sdist/amici/jax/petab.py
@@ -13,6 +13,7 @@

 import diffrax
 import equinox as eqx
+import jax.lax
 import jax.numpy as jnp
 import numpy as np
 import pandas as pd
@@ -35,7 +36,7 @@ def jax_unscale(
     parameter:
         Parameter to be unscaled.
     scale_str:
-        One of ``'lin'`` (synonymous with ``''``), ``'log'``, ``'log10'``.
+        One of ``petab.LIN``, ``petab.LOG``, ``petab.LOG10``.

     Returns:
         The unscaled parameter.
@@ -51,12 +52,6 @@

 class JAXProblem(eqx.Module):
     """
-    :ivar solver:
-        Diffrax solver to use for model simulation
-    :ivar controller:
-        Step-size controller to use for model simulation
-    :ivar max_steps:
-        Maximum number of steps to take during a simulation
     :ivar parameters:
         Values for the model parameters. Only populated after setting the
         PEtab problem via :meth:`set_petab_problem`. Do not change dimensions,
         values may be changed during, e.g. model training.
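`run_simulation` (see the hunk further below) now wraps all data arrays in `jax.lax.stop_gradient`, which keeps timepoints, measurements, and observable indices out of the differentiation graph. A toy illustration of the effect (names are illustrative, not the AMICI API):

```python
import jax
import jax.numpy as jnp


def loss(inputs):
    p = inputs["p"]
    # treat measurements as constants, as run_simulation does
    data = jax.lax.stop_gradient(inputs["data"])
    return jnp.sum((p - data) ** 2)


grads = jax.grad(loss)({"p": jnp.ones(3), "data": jnp.zeros(3)})
# grads["p"] is non-zero, grads["data"] is all zeros
```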
@@ -72,13 +67,11 @@ class JAXProblem(eqx.Module):

     parameters: jnp.ndarray
     model: JAXModel
-    parameter_mappings: dict[tuple[str], ParameterMappingForCondition] = (
-        eqx.field(static=True)
-    )
+    parameter_mappings: dict[tuple[str], ParameterMappingForCondition]
     measurements: dict[
         tuple[str],
-        tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, str],
-    ] = eqx.field(static=True)
+        tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray],
+    ]
     petab_problem: petab.Problem

     def __init__(self, model: JAXModel, petab_problem: petab.Problem):
@@ -122,30 +115,31 @@ def _get_measurements(self, simulation_conditions: pd.DataFrame):
         """
         measurements = dict()
         for _, simulation_condition in simulation_conditions.iterrows():
-            measurements_df = self.petab_problem.measurement_df
-            for k, v in simulation_condition.items():
-                measurements_df = measurements_df.query(f"{k} == '{v}'")
+            query = " & ".join(
+                [f"{k} == '{v}'" for k, v in simulation_condition.items()]
+            )
+            m = self.petab_problem.measurement_df.query(query)

-            measurements_df.sort_values(by=petab.TIME, inplace=True)
+            m.sort_values(by=petab.TIME, inplace=True)

-            ts = measurements_df[petab.TIME].values
-            ts_dyn = [t for t in ts if np.isfinite(t)]
-            my = measurements_df[petab.MEASUREMENT].values
+            ts = m[petab.TIME].values
+            ts_preeq = ts[np.isfinite(ts) & (ts == 0)]
+            ts_dyn = ts[np.isfinite(ts) & (ts > 0)]
+            ts_posteq = ts[np.logical_not(np.isfinite(ts))]
+            my = m[petab.MEASUREMENT].values
             iys = np.array(
                 [
                     self.model.observable_ids.index(oid)
-                    for oid in measurements_df[petab.OBSERVABLE_ID].values
+                    for oid in m[petab.OBSERVABLE_ID].values
                 ]
             )
-            # using strings here prevents tracing in jax
-            dynamic = ts_dyn and max(ts_dyn) > 0
             measurements[tuple(simulation_condition)] = (
-                np.array(ts),
-                np.array(ts_dyn),
+                ts_preeq,
+                ts_dyn,
+                ts_posteq,
                 my,
                 iys,
-                dynamic,
             )
         return measurements
@@ -236,21 +230,24 @@ def run_simulation(
         controller: diffrax.AbstractStepSizeController,
         max_steps: int,
     ):
-        ts, ts_dyn, my, iys, dynamic = self.measurements[simulation_condition]
+        ts_preeq, ts_dyn, ts_posteq, my, iys = self.measurements[
+            simulation_condition
+        ]
         p = self.load_parameters(simulation_condition[0])
         p_preeq = (
             self.load_parameters(simulation_condition[1])
             if len(simulation_condition) > 1
             else jnp.array([])
         )
         return self.model.simulate_condition(
-            ts,
-            ts_dyn,
-            my,
-            iys,
             p,
             p_preeq,
-            dynamic,
+            jax.lax.stop_gradient(jnp.array(ts_preeq)),
+            jax.lax.stop_gradient(jnp.array(ts_dyn)),
+            jax.lax.stop_gradient(jnp.array(ts_posteq)),
+            jax.lax.stop_gradient(jnp.array(my)),
+            jax.lax.stop_gradient(jnp.array(iys)),
             solver,
             controller,
+            diffrax.RecursiveCheckpointAdjoint(),
             max_steps,
diff --git a/python/tests/test_jax.py b/python/tests/test_jax.py
index 8c78253334..543f8f0544 100644
--- a/python/tests/test_jax.py
+++ b/python/tests/test_jax.py
@@ -5,6 +5,8 @@

 import amici.jax
 import jax.numpy as jnp
+import jax
+import diffrax
 import numpy as np
 from amici.pysb_import import pysb2amici

@@ -109,22 +111,16 @@ def _test_model(model_module, ts, p, k):
     amici_solver.setSensitivityOrder(amici.SensitivityOrder.first)
     rs_amici = amici.runAmiciSimulations(amici_model, amici_solver, [edata])

-    check_fields_jax(rs_amici, jax_model, edata, ["x", "y", "llh"])
-
     check_fields_jax(
-        rs_amici,
-        jax_model,
-        edata,
-        ["x", "y", "llh", "sllh"],
-        sensi_order=amici.SensitivityOrder.first,
+        rs_amici, jax_model, edata, ["x", "y", "llh", "res", "x0"]
     )

     check_fields_jax(
         rs_amici,
         jax_model,
         edata,
-        ["x", "y", "llh", "sllh"],
-        sensi_order=amici.SensitivityOrder.second,
+        ["sllh", "sx0", "sx", "sres", "sy"],
+        sensi_order=amici.SensitivityOrder.first,
     )
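The reworked `check_fields_jax` below drives `simulate_condition` through `jax.grad` and `jax.jacfwd` with `has_aux=True`, since the function now returns a `(value, stats)` pair. A standalone sketch of that calling convention (toy function, not the AMICI API):

```python
import jax
import jax.numpy as jnp


def f(p):
    value = jnp.sum(p**2)  # scalar output, e.g. ret="llh"
    aux = {"stats": None}  # passed through untouched by JAX
    return value, aux


grad, aux = jax.grad(f, has_aux=True)(jnp.ones(3))
# for vector-valued outputs (e.g. ret="res"), use jax.jacfwd instead:
jac, aux = jax.jacfwd(lambda p: (p - 1.0, {"stats": None}), has_aux=True)(
    jnp.ones(3)
)
```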
@@ -136,41 +132,81 @@ def check_fields_jax(
     sensi_order=amici.SensitivityOrder.none,
 ):
     r_jax = dict()
-    kwargs = {
-        "ts": np.array(edata.getTimepoints()),
-        "ts_dyn": np.array(edata.getTimepoints()),
-        "p": np.array(edata.parameters),
-        "k": np.array(edata.fixedParameters),
-        "k_preeq": np.array([]),
-        "my": np.array(edata.getObservedData()).reshape(
-            np.array(edata.getTimepoints()).shape[0], -1
-        ),
-        "pscale": np.array(edata.pscale),
-    }
-    if sensi_order == amici.SensitivityOrder.none:
-        (
-            r_jax["llh"],
-            (r_jax["x"], r_jax["y"], r_jax["stats"]),
-        ) = jax_model._fun(**kwargs)
-    elif sensi_order == amici.SensitivityOrder.first:
-        (
-            r_jax["llh"],
-            r_jax["sllh"],
-            (r_jax["x"], r_jax["y"], r_jax["stats"]),
-        ) = jax_model._grad(**kwargs)
-    elif sensi_order == amici.SensitivityOrder.second:
-        (
-            r_jax["llh"],
-            r_jax["sllh"],
-            r_jax["s2llh"],
-            (r_jax["x"], r_jax["y"], r_jax["stats"]),
-        ) = jax_model._hessian(**kwargs)
+    ts = np.array(edata.getTimepoints())
+    my = np.array(edata.getObservedData()).reshape(len(ts), -1)
+    ts = np.repeat(ts.reshape(-1, 1), my.shape[1], axis=1)
+    iys = np.repeat(np.arange(my.shape[1]).reshape(1, -1), len(ts), axis=0)
+    my = my.flatten()
+    ts = ts.flatten()
+    iys = iys.flatten()
+
+    ts_preeq = ts[ts == 0]
+    ts_dyn = ts[ts > 0]
+    ts_posteq = np.array([])
+    p = jnp.array(list(edata.parameters) + list(edata.fixedParameters))
+    args = (
+        jnp.array([]),  # p_preeq
+        jnp.array(ts_preeq),  # ts_preeq
+        jnp.array(ts_dyn),  # ts_dyn
+        jnp.array(ts_posteq),  # ts_posteq
+        jnp.array(my),  # my
+        jnp.array(iys),  # iys
+        diffrax.Kvaerno5(),  # solver
+        diffrax.PIDController(atol=1e-8, rtol=1e-8),  # controller
+        diffrax.RecursiveCheckpointAdjoint(),  # adjoint
+        2**8,  # max_steps
+    )
+    fun = jax_model.simulate_condition
+
+    for output in ["llh", "x0", "x", "y", "res"]:
+        oargs = (*args[:-2], diffrax.DirectAdjoint(), 2**8, output)
+        if sensi_order == amici.SensitivityOrder.none:
+            r_jax[output] = fun(p, *oargs)[0]
+        if sensi_order == amici.SensitivityOrder.first:
+            if output == "llh":
+                r_jax[f"s{output}"] = jax.grad(fun, has_aux=True)(p, *args)[0]
+            else:
+                r_jax[f"s{output}"] = jax.jacfwd(fun, has_aux=True)(p, *oargs)[
+                    0
+                ]

     for field in fields:
         for r_amici, r_jax in zip(rs_amici, [r_jax]):
+            actual = r_jax[field]
+            desired = r_amici[field]
+            if field == "x":
+                actual = actual[iys == 0, :]
+            if field == "y":
+                actual = np.stack(
+                    [actual[iys == iy] for iy in sorted(np.unique(iys))],
+                    axis=1,
+                )
+            elif field == "sllh":
+                actual = actual[: len(edata.parameters)]
+            elif field == "sx":
+                actual = np.permute_dims(
+                    actual[iys == 0, :, : len(edata.parameters)], (0, 2, 1)
+                )
+            elif field == "sy":
+                actual = np.permute_dims(
+                    np.stack(
+                        [
+                            actual[iys == iy, : len(edata.parameters)]
+                            for iy in sorted(np.unique(iys))
+                        ],
+                        axis=1,
+                    ),
+                    (0, 2, 1),
+                )
+            elif field == "sx0":
+                actual = actual[:, : len(edata.parameters)].T
+            elif field == "sres":
+                actual = actual[:, : len(edata.parameters)]
+
             assert_allclose(
-                actual=r_amici[field],
-                desired=r_jax[field],
-                atol=1e-6,
-                rtol=1e-6,
+                actual=actual,
+                desired=desired,
+                atol=1e-5,
+                rtol=1e-5,
+                err_msg=f"field {field} does not match",
             )
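For reference, the equilibration pattern that `JAXModel._eq` in the model.py diff relies on: integrate towards an infinite horizon and stop once `diffrax.steady_state_event` fires. A self-contained sketch with a toy right-hand side (not AMICI-generated code; the `t1 = jnp.inf` horizon is an assumption based on the usual diffrax steady-state idiom, since `_eq`'s `t0`/`t1` arguments are outside the diff context):

```python
import diffrax
import jax.numpy as jnp


def xdot(t, x, args):
    return -x  # decays to the steady state x = 0


sol = diffrax.diffeqsolve(
    diffrax.ODETerm(xdot),
    diffrax.Kvaerno5(),
    t0=0.0,
    t1=jnp.inf,
    dt0=None,
    y0=jnp.array([1.0]),
    stepsize_controller=diffrax.PIDController(atol=1e-8, rtol=1e-8),
    max_steps=2**10,
    adjoint=diffrax.DirectAdjoint(),
    event=diffrax.Event(cond_fn=diffrax.steady_state_event()),
)
x_ss = sol.ys[-1, :]  # state at the time the event triggered
```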