diff --git a/documentation/conf.py b/documentation/conf.py index c86a145f9d..4445c62069 100644 --- a/documentation/conf.py +++ b/documentation/conf.py @@ -206,6 +206,7 @@ def install_doxygen(): "numpy": ("https://numpy.org/devdocs/", None), "sympy": ("https://docs.sympy.org/latest/", None), "python": ("https://docs.python.org/3", None), + "jax": ["https://jax.readthedocs.io/en/latest/", None], } # Add notebooks prolog with binder links diff --git a/documentation/python_modules.rst b/documentation/python_modules.rst index 2607447f0d..096dd0735f 100644 --- a/documentation/python_modules.rst +++ b/documentation/python_modules.rst @@ -25,6 +25,7 @@ AMICI Python API amici.petab_objective amici.petab_simulate amici.import_utils + amici.jax amici.de_export amici.de_model amici.de_model_components diff --git a/python/sdist/amici/de_export.py b/python/sdist/amici/de_export.py index 1bace90510..4865851265 100644 --- a/python/sdist/amici/de_export.py +++ b/python/sdist/amici/de_export.py @@ -21,6 +21,7 @@ TYPE_CHECKING, Literal, ) +from itertools import chain import sympy as sp @@ -300,30 +301,38 @@ def jnp_array_str(array) -> str: return f"jnp.array([{elems}])" + # replaces Heaviside variables with corresponding functions + subs_heaviside = dict( + zip( + self.model.sym("h"), + [sp.Heaviside(x) for x in self.model.eq("root")], + strict=True, + ) + ) + # replaces observables with a generic my variable + subs_observables = dict( + zip( + self.model.sym("my"), + [sp.Symbol("my")] * len(self.model.sym("my")), + strict=True, + ) + ) + tpl_data = { + # assign named variable using corresponding algebraic formula (function body) **{ f"{eq_name.upper()}_EQ": "\n".join( self._code_printer_jax._get_sym_lines( (str(strip_pysb(s)) for s in self.model.sym(eq_name)), self.model.eq(eq_name).subs( - dict( - zip( - list(self.model.sym("h")) - + list(self.model.sym("my")), - [ - sp.Heaviside(x) - for x in self.model.eq("root") - ] - + [sp.Symbol("my")] - * len(self.model.sym("my")), - ) - ) + {**subs_heaviside, **subs_observables} ), indent, ) - )[indent:] + )[indent:] # remove indent for first line for eq_name in eq_names }, + # create jax array from concatenation of named variables **{ f"{eq_name.upper()}_RET": jnp_array_str( strip_pysb(s) for s in self.model.sym(eq_name) @@ -332,6 +341,7 @@ def jnp_array_str(array) -> str: else "jnp.array([])" for eq_name in eq_names }, + # assign named variables from a jax array **{ f"{sym_name.upper()}_SYMS": "".join( str(strip_pysb(s)) + ", " for s in self.model.sym(sym_name) @@ -340,6 +350,7 @@ def jnp_array_str(array) -> str: else "_" for sym_name in sym_names }, + # tuple of variable names (ids as they are unique) **{ f"{sym_name.upper()}_IDS": "".join( f'"{strip_pysb(s)}", ' for s in self.model.sym(sym_name) @@ -349,19 +360,19 @@ def jnp_array_str(array) -> str: for sym_name in ("p", "k", "y", "x") }, **{ + # in jax model we do not need to distinguish between p (parameters) and + # k (fixed parameters) so we use a single variable combining both "PK_SYMS": "".join( str(strip_pysb(s)) + ", " - for s in list(self.model.sym("p")) - + list(self.model.sym("k")) + for s in chain(self.model.sym("p"), self.model.sym("k")) ), "PK_IDS": "".join( f'"{strip_pysb(s)}", ' - for s in list(self.model.sym("p")) - + list(self.model.sym("k")) + for s in chain(self.model.sym("p"), self.model.sym("k")) ), - }, - **{ "MODEL_NAME": self.model_name, + # keep track of the API version that the model was generated with so we + # can flag conflicts in the future "MODEL_API_VERSION": f"'{JAXModel.MODEL_API_VERSION}'", }, } diff --git a/python/sdist/amici/jax/__init__.py b/python/sdist/amici/jax/__init__.py index e69de29bb2..7f8575e88e 100644 --- a/python/sdist/amici/jax/__init__.py +++ b/python/sdist/amici/jax/__init__.py @@ -0,0 +1 @@ +"""Interface to facilitate AMICI generated models using JAX""" diff --git a/python/sdist/amici/jax/model.py b/python/sdist/amici/jax/model.py index 9335d1a0a7..ceeea8d817 100644 --- a/python/sdist/amici/jax/model.py +++ b/python/sdist/amici/jax/model.py @@ -173,6 +173,7 @@ def _llh( ) -> jt.Float[jt.Scalar, ""]: """ Compute the log-likelihood of the observable for the specified observable index. + :param t: time point :param x: @@ -430,10 +431,11 @@ def simulate_condition( controller: diffrax.AbstractStepSizeController, adjoint: diffrax.AbstractAdjoint, max_steps: int | jnp.int_, - ret: str = "llh", + ret: str = "nllh", ): r""" Simulate a condition. + :param p: parameters for simulation ordered according to ids in :ivar parameter_ids: :param p_preeq: @@ -464,8 +466,8 @@ def simulate_condition( maximum number of solver steps :param ret: which output to return. Valid values are - - `llh`: negative log-likelihood (default) - - `llhs`: negative log-likelihoods at each time point + - `nllh`: negative log-likelihood (default) + - `llhs`: log-likelihoods at each time point - `x0`: full initial state vector (after pre-equilibration) - `x0_solver`: reduced initial state vector (after pre-equilibration) - `x`: full state vector @@ -532,9 +534,9 @@ def simulate_condition( x = jnp.concatenate((x_preq, x_dyn, x_posteq), axis=0) llhs = self._llhs(ts, x, p, tcl, my, iys) - llh = -jnp.sum(llhs) + nllh = -jnp.sum(llhs) return { - "llh": llh, + "nllh": nllh, "llhs": llhs, "x": self._x_rdatas(x, tcl), "x_solver": x, diff --git a/python/sdist/amici/jax/petab.py b/python/sdist/amici/jax/petab.py index aae83f410c..b1ee96e167 100644 --- a/python/sdist/amici/jax/petab.py +++ b/python/sdist/amici/jax/petab.py @@ -108,7 +108,7 @@ def _get_parameter_mappings( for sim_var, value in mapping.map_sim_var.items(): if isinstance(value, Number) and not np.isfinite(value): mapping.map_sim_var[sim_var] = 1.0 - return dict(zip(scs, mappings)) + return dict(zip(scs, mappings, strict=True)) def _get_measurements( self, simulation_conditions: pd.DataFrame @@ -117,7 +117,7 @@ def _get_measurements( tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray], ]: """ - Set measurements for the model based on the provided simulation conditions. + Get measurements for the model based on the provided simulation conditions. :param simulation_conditions: Simulation conditions to create parameter mappings for. Same format as returned by @@ -156,17 +156,13 @@ def _get_measurements( ) return measurements - def _get_nominal_parameter_values(self) -> jnp.ndarray: + def _get_nominal_parameter_values(self) -> jt.Float[jt.Array, "np"]: """ - Set the nominal parameter values for the model based on the nominal values in the PEtab problem. + Get the nominal parameter values for the model based on the nominal values in the PEtab problem. :return: - JAXModel instance with parameter values set to the nominal values. + jax array with nominal parameter values """ - if self._petab_problem is None: - raise ValueError( - "PEtab problem not set, cannot set nominal values." - ) return jnp.array( [ petab.scale( @@ -306,7 +302,7 @@ def run_simulations( icoeff=0.3, dcoeff=0.0, ), - max_steps: int = 2**14, + max_steps: int = 2**10, ): """ Run simulations for a problem. diff --git a/tests/benchmark-models/test_petab_benchmark.py b/tests/benchmark-models/test_petab_benchmark.py index 132402f3c8..7a0afc6832 100644 --- a/tests/benchmark-models/test_petab_benchmark.py +++ b/tests/benchmark-models/test_petab_benchmark.py @@ -9,13 +9,10 @@ from pathlib import Path import fiddy import amici -import equinox as eqx -import jax.numpy as jnp import numpy as np import pandas as pd import petab.v1 as petab import pytest -import jax from amici.petab.petab_import import import_petab_problem import benchmark_models_petab from collections import defaultdict @@ -37,11 +34,8 @@ rdatas_to_measurement_df, simulate_petab, ) -from amici.jax.petab import run_simulations, JAXProblem -from petab.v1.visualize import plot_problem -from beartype import beartype -jax.config.update("jax_enable_x64", True) +from petab.v1.visualize import plot_problem # Enable various debug output @@ -267,6 +261,14 @@ def benchmark_problem(request): "ignore:Adjoint sensitivity analysis for models with discontinuous ", ) def test_jax_llh(benchmark_problem): + import jax + import equinox as eqx + import jax.numpy as jnp + from amici.jax.petab import run_simulations, JAXProblem + + jax.config.update("jax_enable_x64", True) + from beartype import beartype + problem_id, petab_problem, amici_model = benchmark_problem if problem_id in (