diff --git a/pytest.ini b/pytest.ini
index adbf313922..8cc45e0fd9 100644
--- a/pytest.ini
+++ b/pytest.ini
@@ -12,6 +12,7 @@ filterwarnings =
     ignore:Conservation laws for non-constant species in models with Species-AssignmentRules are currently not supported and will be turned off.:UserWarning
     ignore:Conservation laws for non-constant species in combination with parameterized stoichiometric coefficients are not currently supported and will be turned off.:UserWarning
     ignore:Support for PEtab2.0 is experimental!:UserWarning
+    ignore:The JAX module is experimental and the API may change in the future.:ImportWarning
     # hundreds of SBML <=5.17 warnings
     ignore:.*inspect.getargspec\(\) is deprecated.*:DeprecationWarning
     # pysb warnings
diff --git a/python/examples/example_jax_petab/ExampleJaxPEtab.ipynb b/python/examples/example_jax_petab/ExampleJaxPEtab.ipynb
index 855860e242..f6a4f10e98 100644
--- a/python/examples/example_jax_petab/ExampleJaxPEtab.ipynb
+++ b/python/examples/example_jax_petab/ExampleJaxPEtab.ipynb
@@ -151,7 +151,7 @@
     "        results (dict): Simulation results from run_simulations.\n",
     "    \"\"\"\n",
     "    # Extract the simulation results for the specific condition\n",
-    "    sim_results = results[simulation_condition][1]\n",
+    "    sim_results = results[simulation_condition]\n",
     "\n",
     "    # Create a new figure for the state trajectories\n",
     "    plt.figure(figsize=(8, 6))\n",
@@ -357,14 +357,12 @@
     "simulation_condition = (\"model1_data1\",)\n",
     "\n",
     "# Load condition-specific data\n",
-    "ts_preeq, ts_dyn, ts_posteq, my, iys = jax_problem._measurements[\n",
+    "ts_init, ts_dyn, ts_posteq, my, iys = jax_problem._measurements[\n",
     "    simulation_condition\n",
     "]\n",
     "\n",
     "# Load parameters for the specified condition\n",
     "p = jax_problem.load_parameters(simulation_condition[0])\n",
-    "# Disable preequilibration\n",
-    "p_preeq = jnp.array([])\n",
     "\n",
     "\n",
     "# Define a function to compute the gradient with respect to dynamic timepoints\n",
@@ -372,12 +370,12 @@
     "def grad_ts_dyn(tt):\n",
     "    return jax_problem.model.simulate_condition(\n",
     "        p=p,\n",
-    "        p_preeq=p_preeq,\n",
-    "        ts_preeq=ts_preeq,\n",
+    "        ts_init=ts_init,\n",
     "        ts_dyn=tt,\n",
     "        ts_posteq=ts_posteq,\n",
     "        my=jnp.array(my),\n",
     "        iys=jnp.array(iys),\n",
+    "        x_preeq=jnp.array([]),\n",
     "        solver=diffrax.Kvaerno5(),\n",
     "        controller=diffrax.PIDController(atol=1e-8, rtol=1e-8),\n",
     "        max_steps=2**10,\n",
@@ -489,7 +487,6 @@
     "amici_model = import_petab_problem(\n",
     "    petab_problem,\n",
     "    verbose=False,\n",
-    "    compile_=True,\n",
     "    jax=False,  # load the amici model this time\n",
     ")\n",
     "\n",
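Usage sketch (editor illustration, not part of the patch): with the updated return value of run_simulations used in the notebook above, per-condition statistics are indexed directly rather than unpacked from a nested tuple. The condition name is taken from the example notebook; jax_problem is assumed to be the JAXProblem built there.

    from amici.jax import run_simulations

    # sum of log-likelihoods and per-condition statistics
    llh, results = run_simulations(jax_problem)
    stats = results[("model1_data1",)]  # dict with "ts", "x", "stats_dyn", ...
    print(llh, stats["ts"].shape, stats["x"].shape)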
+""" + +from warnings import warn from amici.jax.petab import JAXProblem, run_simulations from amici.jax.model import JAXModel from amici.jax.nn import generate_equinox +warn( + "The JAX module is experimental and the API may change in the future.", + ImportWarning, + stacklevel=2, +) + __all__ = ["JAXModel", "JAXProblem", "run_simulations", "generate_equinox"] diff --git a/python/sdist/amici/jax/model.py b/python/sdist/amici/jax/model.py index 47790c98a5..7d2ff15709 100644 --- a/python/sdist/amici/jax/model.py +++ b/python/sdist/amici/jax/model.py @@ -428,12 +428,12 @@ def _sigmays( def simulate_condition( self, p: jt.Float[jt.Array, "np"], - p_preeq: jt.Float[jt.Array, "*np"], - ts_preeq: jt.Float[jt.Array, "nt_preeq"], + ts_init: jt.Float[jt.Array, "nt_preeq"], ts_dyn: jt.Float[jt.Array, "nt_dyn"], ts_posteq: jt.Float[jt.Array, "nt_posteq"], my: jt.Float[jt.Array, "nt"], iys: jt.Int[jt.Array, "nt"], + x_preeq: jt.Float[jt.Array, "nx"], solver: diffrax.AbstractSolver, controller: diffrax.AbstractStepSizeController, adjoint: diffrax.AbstractAdjoint, @@ -445,12 +445,9 @@ def simulate_condition( :param p: parameters for simulation ordered according to ids in :ivar parameter_ids: - :param p_preeq: - parameters for pre-equilibration ordered according to ids in :ivar parameter_ids:. May be empty to - disable pre-equilibration. - :param ts_preeq: - time points for pre-equilibration. Usually valued 0.0, but needs to be shaped according to - the number of observables that are evaluated after pre-equilibration. + :param ts_init: + time points that do not require simulation. Usually valued 0.0, but needs to be shaped according to + the number of observables that are evaluated before dynamic simulation. :param ts_dyn: time points for dynamic simulation. Usually valued > 0.0 and sorted in monotonically increasing order. Duplicate time points are allowed to facilitate the evaluation of multiple observables at specific time @@ -487,24 +484,16 @@ def simulate_condition( output according to `ret` and statistics """ # Pre-equilibration - if p_preeq.shape[0] > 0: - x0 = self._x0(p_preeq) - tcl = self._tcl(x0, p_preeq) - current_x = self._x_solver(x0) - current_x, stats_preeq = self._eq( - p_preeq, tcl, current_x, solver, controller, max_steps - ) + if x_preeq.shape[0] > 0: + current_x = self._x_solver(x_preeq) # update tcl with new parameters - tcl = self._tcl(self._x_rdata(current_x, tcl), p) + tcl = self._tcl(x_preeq, p) else: x0 = self._x0(p) current_x = self._x_solver(x0) - stats_preeq = None tcl = self._tcl(x0, p) - x_preq = jnp.repeat( - current_x.reshape(1, -1), ts_preeq.shape[0], axis=0 - ) + x_preq = jnp.repeat(current_x.reshape(1, -1), ts_init.shape[0], axis=0) # Dynamic simulation if ts_dyn.shape[0] > 0: @@ -537,7 +526,7 @@ def simulate_condition( current_x.reshape(1, -1), ts_posteq.shape[0], axis=0 ) - ts = jnp.concatenate((ts_preeq, ts_dyn, ts_posteq), axis=0) + ts = jnp.concatenate((ts_init, ts_dyn, ts_posteq), axis=0) x = jnp.concatenate((x_preq, x_dyn, x_posteq), axis=0) nllhs = self._nllhs(ts, x, p, tcl, my, iys) @@ -556,11 +545,42 @@ def simulate_condition( }[ret], dict( ts=ts, x=x, - stats_preeq=stats_preeq, stats_dyn=stats_dyn, stats_posteq=stats_posteq, ) + @eqx.filter_jit + def preequilibrate_condition( + self, + p: jt.Float[jt.Array, "np"], + solver: diffrax.AbstractSolver, + controller: diffrax.AbstractStepSizeController, + max_steps: int | jnp.int_, + ) -> tuple[jt.Float[jt.Array, "nx"], dict]: + r""" + Simulate a condition. 
diff --git a/python/sdist/amici/jax/ode_export.py b/python/sdist/amici/jax/ode_export.py
index f36f67ab85..2fcc1aa718 100644
--- a/python/sdist/amici/jax/ode_export.py
+++ b/python/sdist/amici/jax/ode_export.py
@@ -247,12 +247,10 @@ def _generate_jax_code(self) -> None:
                 for net in self.hybridisation.keys()
             ),
         }
-        outdir = self.model_path / (self.model_name + "_jax")
-        outdir.mkdir(parents=True, exist_ok=True)
 
         apply_template(
             Path(amiciModulePath) / "jax" / "jax.template.py",
-            outdir / "__init__.py",
+            self.model_path / "__init__.py",
             tpl_data,
         )
 
@@ -280,6 +278,7 @@ def set_paths(self, output_dir: str | Path | None = None) -> None:
             output_dir = Path(os.getcwd()) / f"amici-{self.model_name}"
 
         self.model_path = Path(output_dir).resolve()
+        self.model_path.mkdir(parents=True, exist_ok=True)
 
     def set_name(self, model_name: str) -> None:
         """
diff --git a/python/sdist/amici/jax/petab.py b/python/sdist/amici/jax/petab.py
index 75e346bfe6..3e653b72fc 100644
--- a/python/sdist/amici/jax/petab.py
+++ b/python/sdist/amici/jax/petab.py
@@ -155,7 +155,7 @@ def _get_parameter_mappings(
     def _get_measurements(
         self, simulation_conditions: pd.DataFrame
     ) -> dict[
-        tuple[str],
+        tuple[str, ...],
        tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray],
     ]:
         """
@@ -412,20 +412,22 @@
         solver: diffrax.AbstractSolver,
         controller: diffrax.AbstractStepSizeController,
         max_steps: jnp.int_,
+        x_preeq: jt.Float[jt.Array, "*nx"] = jnp.array([]),  # noqa: F821, F722
         ret: str = "llh",
     ) -> tuple[jnp.float_, dict]:
         """
         Run a simulation for a given simulation condition.
 
         :param simulation_condition:
-            Tuple of simulation conditions to run the simulation for. can be a single string (simulation only) or a
-            tuple of strings (pre-equilibration followed by simulation).
+            Simulation condition to run the simulation for.
         :param solver:
             ODE solver to use for simulation
         :param controller:
             Step size controller to use for simulation
         :param max_steps:
             Maximum number of steps to take during simulation
+        :param x_preeq:
+            Pre-equilibration state if available
         :param ret:
             which output to return. Valid values are
                 - `llh`: log-likelihood (default)
@@ -445,19 +447,14 @@
             simulation_condition
         ]
         p = self.load_parameters(simulation_condition[0])
-        p_preeq = (
-            self.load_parameters(simulation_condition[1])
-            if len(simulation_condition) > 1
-            else jnp.array([])
-        )
         return self.model.simulate_condition(
             p=p,
-            p_preeq=p_preeq,
-            ts_preeq=jax.lax.stop_gradient(jnp.array(ts_preeq)),
+            ts_init=jax.lax.stop_gradient(jnp.array(ts_preeq)),
             ts_dyn=jax.lax.stop_gradient(jnp.array(ts_dyn)),
             ts_posteq=jax.lax.stop_gradient(jnp.array(ts_posteq)),
             my=jax.lax.stop_gradient(jnp.array(my)),
             iys=jax.lax.stop_gradient(jnp.array(iys)),
+            x_preeq=x_preeq,
             solver=solver,
             controller=controller,
             max_steps=max_steps,
@@ -467,10 +464,39 @@
             ret=ret,
         )
 
+    def run_preequilibration(
+        self,
+        simulation_condition: str,
+        solver: diffrax.AbstractSolver,
+        controller: diffrax.AbstractStepSizeController,
+        max_steps: jnp.int_,
+    ) -> tuple[jt.Float[jt.Array, "nx"], dict]:  # noqa: F821
+        """
+        Run a pre-equilibration simulation for a given simulation condition.
+
+        :param simulation_condition:
+            Simulation condition to run the pre-equilibration for.
+        :param solver:
+            ODE solver to use for simulation
+        :param controller:
+            Step size controller to use for simulation
+        :param max_steps:
+            Maximum number of steps to take during simulation
+        :return:
+            Pre-equilibration state
+        """
+        p = self.load_parameters(simulation_condition)
+        return self.model.preequilibrate_condition(
+            p=p,
+            solver=solver,
+            controller=controller,
+            max_steps=max_steps,
+        )
+
 
 def run_simulations(
     problem: JAXProblem,
-    simulation_conditions: Iterable[tuple] | None = None,
+    simulation_conditions: Iterable[tuple[str, ...]] | None = None,
     solver: diffrax.AbstractSolver = diffrax.Kvaerno5(),
     controller: diffrax.AbstractStepSizeController = diffrax.PIDController(
         rtol=1e-8,
@@ -513,12 +539,28 @@
     if simulation_conditions is None:
         simulation_conditions = problem.get_all_simulation_conditions()
 
+    preeqs = {
+        sc: problem.run_preequilibration(sc, solver, controller, max_steps)
+        # only run preequilibration once per condition
+        for sc in {sc[1] for sc in simulation_conditions if len(sc) > 1}
+    }
+
     results = {
-        sc: problem.run_simulation(sc, solver, controller, max_steps, ret)
+        sc: problem.run_simulation(
+            sc,
+            solver,
+            controller,
+            max_steps,
+            preeqs.get(sc[1])[0] if len(sc) > 1 else jnp.array([]),
+            ret,
+        )
         for sc in simulation_conditions
     }
     if ret == "llh":
         output = sum(llh for llh, _ in results.values())
     else:
         output = {sc: res for sc, (res, _) in results.items()}
-    return output, results
+    return output, {
+        sc: res[1] | preeqs[sc[1]][1] if len(sc) > 1 else res[1]
+        for sc, res in results.items()
+    }
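Usage sketch (not part of the patch): at the PEtab layer, run_simulations now pre-equilibrates each pre-equilibration condition once and reuses the resulting state for every simulation condition that references it. The condition IDs below ("preeq_c0", "c1") are hypothetical; jax_problem is assumed to be an existing JAXProblem.

    import diffrax

    solver = diffrax.Kvaerno5()
    controller = diffrax.PIDController(rtol=1e-8, atol=1e-8)
    max_steps = 2**14

    # what run_simulations does internally for a (simulation, preeq) condition pair
    x_eq, preeq_stats = jax_problem.run_preequilibration(
        "preeq_c0", solver, controller, max_steps
    )
    llh, stats = jax_problem.run_simulation(
        ("c1", "preeq_c0"), solver, controller, max_steps, x_preeq=x_eq
    )
    # conditions without pre-equilibration simply omit x_preeq
    llh_c1, _ = jax_problem.run_simulation(("c1",), solver, controller, max_steps)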
diff --git a/python/sdist/amici/petab/import_helpers.py b/python/sdist/amici/petab/import_helpers.py
index daa902efb0..d42e99b1e3 100644
--- a/python/sdist/amici/petab/import_helpers.py
+++ b/python/sdist/amici/petab/import_helpers.py
@@ -138,10 +138,9 @@ def _can_import_model(
     Check whether a module of that name can already be imported.
     """
     # try to import (in particular checks version)
-    suffix = "_jax" if jax else ""
     try:
         model_module = amici.import_model_module(
-            model_name + suffix, model_output_dir
+            *_get_package_name_and_path(model_name, model_output_dir, jax)
         )
     except ModuleNotFoundError:
         return False
@@ -271,3 +270,25 @@ def check_model(
             "the current model might also resolve this. Parameters: "
             f"{amici_ids_free_required.difference(amici_ids_free)}"
         )
+
+
+def _get_package_name_and_path(
+    model_name: str, model_output_dir: str | Path, jax: bool = False
+) -> tuple[str, Path]:
+    """
+    Get the package name and path for the generated model module.
+
+    :param model_name:
+        Name of the model
+    :param model_output_dir:
+        Target directory for the generated model module
+    :param jax:
+        Whether to generate the paths for a JAX or CPP model
+    :return:
+        Tuple of the package name and the path of the directory containing the package
+    """
+    if jax:
+        outdir = Path(model_output_dir)
+        return outdir.stem, outdir.parent
+    else:
+        return model_name, Path(model_output_dir)
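Illustration (not part of the patch) of the internal helper added above: for JAX models the output directory itself is the Python package, while C++ models keep the model-name/output-directory split. The paths below are made up.

    from pathlib import Path

    from amici.petab.import_helpers import _get_package_name_and_path

    _get_package_name_and_path("my_model", "amici_models")
    # -> ("my_model", Path("amici_models"))

    _get_package_name_and_path("my_model", "amici_models/my_model_jax", jax=True)
    # -> ("my_model_jax", Path("amici_models"))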
diff --git a/python/sdist/amici/petab/petab_import.py b/python/sdist/amici/petab/petab_import.py
index c7ba576c21..c23736cd4a 100644
--- a/python/sdist/amici/petab/petab_import.py
+++ b/python/sdist/amici/petab/petab_import.py
@@ -16,7 +16,12 @@
 from petab.v1.models import MODEL_TYPE_PYSB, MODEL_TYPE_SBML
 
 from ..logging import get_logger
-from .import_helpers import _can_import_model, _create_model_name, check_model
+from .import_helpers import (
+    _can_import_model,
+    _create_model_name,
+    check_model,
+    _get_package_name_and_path,
+)
 from .sbml_import import import_model_sbml
 
 try:
@@ -114,7 +119,7 @@ def import_petab_problem(
         from .sbml_import import _create_model_output_dir_name
 
         model_output_dir = _create_model_output_dir_name(
-            petab_problem.sbml_model, model_name
+            petab_problem.sbml_model, model_name, jax=jax
         )
     else:
         model_output_dir = os.path.abspath(model_output_dir)
@@ -136,7 +141,7 @@
         )
 
         # remove folder if exists
-        if os.path.exists(model_output_dir):
+        if not jax and os.path.exists(model_output_dir):
             shutil.rmtree(model_output_dir)
 
         logger.info(f"Compiling model {model_name} to {model_output_dir}.")
@@ -193,9 +198,8 @@
     )
 
     # import model
-    suffix = "_jax" if jax else ""
     model_module = amici.import_model_module(
-        model_name + suffix, model_output_dir
+        *_get_package_name_and_path(model_name, model_output_dir, jax=jax)
     )
 
     if jax:
diff --git a/python/sdist/amici/petab/sbml_import.py b/python/sdist/amici/petab/sbml_import.py
index e6c0b02b32..4bd9842fdf 100644
--- a/python/sdist/amici/petab/sbml_import.py
+++ b/python/sdist/amici/petab/sbml_import.py
@@ -589,7 +589,9 @@ def _get_fixed_parameters_sbml(
 
 
 def _create_model_output_dir_name(
-    sbml_model: "libsbml.Model", model_name: str | None = None
+    sbml_model: "libsbml.Model",
+    model_name: str | None = None,
+    jax: bool = False,
 ) -> Path:
     """
     Find a folder for storing the compiled amici model.
@@ -600,12 +602,13 @@ def _create_model_output_dir_name(
     BASE_DIR = Path("amici_models").absolute()
     BASE_DIR.mkdir(exist_ok=True)
     # try model_name
+    suffix = "_jax" if jax else ""
     if model_name:
-        return BASE_DIR / model_name
+        return BASE_DIR / (model_name + suffix)
 
     # try sbml model id
     if sbml_model_id := sbml_model.getId():
-        return BASE_DIR / sbml_model_id
+        return BASE_DIR / (sbml_model_id + suffix)
 
     # create random folder name
     return Path(tempfile.mkdtemp(dir=BASE_DIR))
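Usage sketch (not part of the patch): together, the import-side changes mean that importing a PEtab problem with jax=True writes the generated module into a separate "<name>_jax" folder and no longer removes an existing C++ model directory. The YAML path below is a placeholder.

    import petab.v1 as petab
    from amici.petab.petab_import import import_petab_problem

    petab_problem = petab.Problem.from_yaml("model1/problem.yaml")

    jax_model = import_petab_problem(petab_problem, jax=True)  # amici_models/<name>_jax/
    amici_model = import_petab_problem(petab_problem, jax=False)  # amici_models/<name>/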
diff --git a/python/tests/test_jax.py b/python/tests/test_jax.py
index 8f4c68510b..ce7018e078 100644
--- a/python/tests/test_jax.py
+++ b/python/tests/test_jax.py
@@ -47,7 +47,7 @@ def test_conversion():
         module_name=model.name, module_path=outdir
     )
     jax_module = amici.import_model_module(
-        module_name=model.name + "_jax", module_path=outdir
+        module_name=Path(outdir).stem, module_path=Path(outdir).parent
     )
 
     ts = tuple(np.linspace(0, 1, 10))
@@ -108,7 +108,7 @@ def test_dimerization():
         module_name=model.name, module_path=outdir
     )
     jax_module = amici.import_model_module(
-        module_name=model.name + "_jax", module_path=outdir
+        module_name=Path(outdir).stem, module_path=Path(outdir).parent
     )
 
     ts = tuple(np.linspace(0, 1, 10))
@@ -178,7 +178,7 @@ def check_fields_jax(
     ts = ts.flatten()
     iys = iys.flatten()
 
-    ts_preeq = ts[ts == 0]
+    ts_init = ts[ts == 0]
     ts_dyn = ts[ts > 0]
     ts_posteq = np.array([])
 
@@ -188,31 +188,37 @@
     }
     p = jnp.array([par_dict[par_id] for par_id in jax_model.parameter_ids])
 
-    args = (
-        jnp.array([]),  # p_preeq
-        jnp.array(ts_preeq),  # ts_preeq
-        jnp.array(ts_dyn),  # ts_dyn
-        jnp.array(ts_posteq),  # ts_posteq
-        jnp.array(my),  # my
-        jnp.array(iys),  # iys
-        diffrax.Kvaerno5(),  # solver
-        diffrax.PIDController(atol=ATOL_SIM, rtol=RTOL_SIM),  # controller
-        diffrax.RecursiveCheckpointAdjoint(),  # adjoint
-        2**8,  # max_steps
-    )
+    kwargs = {
+        "ts_init": jnp.array(ts_init),
+        "ts_dyn": jnp.array(ts_dyn),
+        "ts_posteq": jnp.array(ts_posteq),
+        "my": jnp.array(my),
+        "iys": jnp.array(iys),
+        "x_preeq": jnp.array([]),
+        "solver": diffrax.Kvaerno5(),
+        "controller": diffrax.PIDController(atol=ATOL_SIM, rtol=RTOL_SIM),
+        "adjoint": diffrax.RecursiveCheckpointAdjoint(),
+        "max_steps": 2**8,  # max_steps
+    }
     fun = beartype(jax_model.simulate_condition)
 
     for output in ["llh", "x0", "x", "y", "res"]:
-        oargs = (*args[:-2], diffrax.DirectAdjoint(), 2**8, output)
+        okwargs = kwargs | {
+            "adjoint": diffrax.DirectAdjoint(),
+            "max_steps": 2**8,
+            "ret": output,
+        }
         if sensi_order == amici.SensitivityOrder.none:
-            r_jax[output] = fun(p, *oargs)[0]
+            r_jax[output] = fun(p, **okwargs)[0]
         if sensi_order == amici.SensitivityOrder.first:
             if output == "llh":
-                r_jax[f"s{output}"] = jax.grad(fun, has_aux=True)(p, *args)[0]
-            else:
-                r_jax[f"s{output}"] = jax.jacfwd(fun, has_aux=True)(p, *oargs)[
+                r_jax[f"s{output}"] = jax.grad(fun, has_aux=True)(p, **kwargs)[
                     0
                 ]
+            else:
+                r_jax[f"s{output}"] = jax.jacfwd(fun, has_aux=True)(
+                    p, **okwargs
+                )[0]
 
     amici_par_idx = np.array(
         [jax_model.parameter_ids.index(par_id) for par_id in parameter_ids]
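Sketch (not part of the patch): as the keyword-based test calls above suggest, JAX transforms can be applied directly to simulate_condition. Here jax_model, p and kwargs are assumed to be set up as in the test, with kwargs holding the simulate_condition keyword arguments.

    import diffrax
    import jax

    # gradient of the log-likelihood with respect to the parameters
    sllh, stats = jax.grad(jax_model.simulate_condition, has_aux=True)(p, **kwargs)

    # forward-mode Jacobian of the residuals; DirectAdjoint supports forward-mode
    sres, _ = jax.jacfwd(jax_model.simulate_condition, has_aux=True)(
        p, **(kwargs | {"ret": "res", "adjoint": diffrax.DirectAdjoint()})
    )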
diff --git a/tests/benchmark-models/test_petab_benchmark.py b/tests/benchmark-models/test_petab_benchmark.py
index 14760ec0ed..f22e59e915 100644
--- a/tests/benchmark-models/test_petab_benchmark.py
+++ b/tests/benchmark-models/test_petab_benchmark.py
@@ -328,7 +328,7 @@ def test_jax_llh(benchmark_problem):
 
     jax_model = import_petab_problem(
         petab_problem,
-        model_output_dir=benchmark_outdir / problem_id,
+        model_output_dir=benchmark_outdir / (problem_id + "_jax"),
         jax=True,
     )
     jax_problem = JAXProblem(jax_model, petab_problem)
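Note (not part of the patch): since importing amici.jax now emits an ImportWarning, test suites or applications that escalate warnings to errors may want a filter analogous to the pytest.ini entry above; a sketch for plain Python code:

    import warnings

    with warnings.catch_warnings():
        warnings.filterwarnings(
            "ignore",
            message="The JAX module is experimental.*",
            category=ImportWarning,
        )
        from amici.jax import JAXProblem, run_simulations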