From 0d49041435f428e50af0dabddfc32c20015c4a1a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabian=20Fr=C3=B6hlich?= Date: Tue, 3 Dec 2024 12:01:06 +0000 Subject: [PATCH] Decouple JAX & C++ code generation (#2615) * refactor * correct doc * remove PK * fix tests * fix notebook * fix parameter ids * fix jax test * fix notebook * reviews * fixup --- python/sdist/amici/__init__.py | 2 - python/sdist/amici/__init__.template.py | 30 -- python/sdist/amici/de_export.py | 163 +------- python/sdist/amici/{ => jax}/jax.template.py | 36 +- .../sdist/amici/{ => jax}/jaxcodeprinter.py | 0 python/sdist/amici/jax/ode_export.py | 277 +++++++++++++ python/sdist/amici/petab/import_helpers.py | 14 +- python/sdist/amici/petab/petab_import.py | 19 +- python/sdist/amici/petab/pysb_import.py | 45 ++- python/sdist/amici/petab/sbml_import.py | 367 ++++++++++-------- python/sdist/amici/pysb_import.py | 106 ++++- python/sdist/amici/sbml_import.py | 121 +++++- python/tests/test_jax.py | 69 +++- 13 files changed, 832 insertions(+), 417 deletions(-) rename python/sdist/amici/{ => jax}/jax.template.py (71%) rename python/sdist/amici/{ => jax}/jaxcodeprinter.py (100%) create mode 100644 python/sdist/amici/jax/ode_export.py diff --git a/python/sdist/amici/__init__.py b/python/sdist/amici/__init__.py index 6788fefe77..0a7d3c6581 100644 --- a/python/sdist/amici/__init__.py +++ b/python/sdist/amici/__init__.py @@ -141,8 +141,6 @@ def get_model(self) -> amici.Model: """Create a model instance.""" ... - def get_jax_model(self) -> JAXModel: ... - AmiciModel = Union[amici.Model, amici.ModelPtr] else: ModelModule = ModuleType diff --git a/python/sdist/amici/__init__.template.py b/python/sdist/amici/__init__.template.py index efc8df0617..06302eba9d 100644 --- a/python/sdist/amici/__init__.template.py +++ b/python/sdist/amici/__init__.template.py @@ -1,16 +1,10 @@ """AMICI-generated module for model TPL_MODELNAME""" -import datetime -import os import sys from pathlib import Path -from typing import TYPE_CHECKING import amici -if TYPE_CHECKING: - from amici.jax import JAXModel - # Ensure we are binary-compatible, see #556 if "TPL_AMICI_VERSION" != amici.__version__: raise amici.AmiciVersionError( @@ -38,28 +32,4 @@ # when the model package is imported via `import` TPL_MODELNAME._model_module = sys.modules[__name__] - -def get_jax_model() -> "JAXModel": - # If the model directory was meanwhile overwritten, this would load the - # new version, which would not match the previously imported extension. - # This is not allowed, as it would lead to inconsistencies. - jax_py_file = Path(__file__).parent / "jax.py" - jax_py_file = jax_py_file.resolve() - t_imported = TPL_MODELNAME._get_import_time() # noqa: protected-access - t_modified = os.path.getmtime(jax_py_file) - if t_imported < t_modified: - t_imp_str = datetime.datetime.fromtimestamp(t_imported).isoformat() - t_mod_str = datetime.datetime.fromtimestamp(t_modified).isoformat() - raise RuntimeError( - f"Refusing to import {jax_py_file} which was changed since " - f"TPL_MODELNAME was imported. This is to avoid inconsistencies " - "between the different model implementations.\n" - f"Imported at {t_imp_str}\nModified at {t_mod_str}.\n" - "Import the module with a different name or restart the " - "Python kernel." - ) - jax = amici._module_from_path("jax", jax_py_file) - return jax.JAXModel_TPL_MODELNAME() - - __version__ = "TPL_PACKAGE_VERSION" diff --git a/python/sdist/amici/de_export.py b/python/sdist/amici/de_export.py index 416dec5694..f0ec08133f 100644 --- a/python/sdist/amici/de_export.py +++ b/python/sdist/amici/de_export.py @@ -21,7 +21,6 @@ TYPE_CHECKING, Literal, ) -from itertools import chain import sympy as sp @@ -56,7 +55,6 @@ AmiciCxxCodePrinter, get_switch_statement, ) -from .jaxcodeprinter import AmiciJaxCodePrinter from .de_model import DEModel from .de_model_components import * from .import_utils import ( @@ -146,10 +144,7 @@ class DEExporter: If the given model uses special functions, this set contains hints for model building. - :ivar _code_printer_jax: - Code printer to generate JAX code - - :ivar _code_printer_cpp: + :ivar _code_printer: Code printer to generate C++ code :ivar generate_sensitivity_code: @@ -218,15 +213,14 @@ def __init__( self.set_name(model_name) self.set_paths(outdir) - self._code_printer_cpp = AmiciCxxCodePrinter() - self._code_printer_jax = AmiciJaxCodePrinter() + self._code_printer = AmiciCxxCodePrinter() for fun in CUSTOM_FUNCTIONS: - self._code_printer_cpp.known_functions[fun["sympy"]] = fun["c++"] + self._code_printer.known_functions[fun["sympy"]] = fun["c++"] # Signatures and properties of generated model functions (see # include/amici/model.h for details) self.model: DEModel = de_model - self._code_printer_cpp.known_functions.update( + self._code_printer.known_functions.update( splines.spline_user_functions( self.model._splines, self._get_index("p") ) @@ -249,7 +243,6 @@ def generate_model_code(self) -> None: sp.Pow, "_eval_derivative", _custom_pow_eval_derivative ): self._prepare_model_folder() - self._generate_jax_code() self._generate_c_code() self._generate_m_code() @@ -277,121 +270,6 @@ def _prepare_model_folder(self) -> None: if os.path.isfile(file_path): os.remove(file_path) - @log_execution_time("generating jax code", logger) - def _generate_jax_code(self) -> None: - try: - from amici.jax.model import JAXModel - except ImportError: - logger.warning( - "Could not import JAXModel. JAX code will not be generated." - ) - return - - eq_names = ( - "xdot", - "w", - "x0", - "y", - "sigmay", - "Jy", - "x_solver", - "x_rdata", - "total_cl", - ) - sym_names = ("x", "tcl", "w", "my", "y", "sigmay", "x_rdata") - - indent = 8 - - def jnp_array_str(array) -> str: - elems = ", ".join(str(s) for s in array) - - return f"jnp.array([{elems}])" - - # replaces Heaviside variables with corresponding functions - subs_heaviside = dict( - zip( - self.model.sym("h"), - [sp.Heaviside(x) for x in self.model.eq("root")], - strict=True, - ) - ) - # replaces observables with a generic my variable - subs_observables = dict( - zip( - self.model.sym("my"), - [sp.Symbol("my")] * len(self.model.sym("my")), - strict=True, - ) - ) - - tpl_data = { - # assign named variable using corresponding algebraic formula (function body) - **{ - f"{eq_name.upper()}_EQ": "\n".join( - self._code_printer_jax._get_sym_lines( - (str(strip_pysb(s)) for s in self.model.sym(eq_name)), - self.model.eq(eq_name).subs( - {**subs_heaviside, **subs_observables} - ), - indent, - ) - )[indent:] # remove indent for first line - for eq_name in eq_names - }, - # create jax array from concatenation of named variables - **{ - f"{eq_name.upper()}_RET": jnp_array_str( - strip_pysb(s) for s in self.model.sym(eq_name) - ) - if self.model.sym(eq_name) - else "jnp.array([])" - for eq_name in eq_names - }, - # assign named variables from a jax array - **{ - f"{sym_name.upper()}_SYMS": "".join( - str(strip_pysb(s)) + ", " for s in self.model.sym(sym_name) - ) - if self.model.sym(sym_name) - else "_" - for sym_name in sym_names - }, - # tuple of variable names (ids as they are unique) - **{ - f"{sym_name.upper()}_IDS": "".join( - f'"{strip_pysb(s)}", ' for s in self.model.sym(sym_name) - ) - if self.model.sym(sym_name) - else "tuple()" - for sym_name in ("p", "k", "y", "x") - }, - **{ - # in jax model we do not need to distinguish between p (parameters) and - # k (fixed parameters) so we use a single variable combining both - "PK_SYMS": "".join( - str(strip_pysb(s)) + ", " - for s in chain(self.model.sym("p"), self.model.sym("k")) - ), - "PK_IDS": "".join( - f'"{strip_pysb(s)}", ' - for s in chain(self.model.sym("p"), self.model.sym("k")) - ), - "MODEL_NAME": self.model_name, - # keep track of the API version that the model was generated with so we - # can flag conflicts in the future - "MODEL_API_VERSION": f"'{JAXModel.MODEL_API_VERSION}'", - }, - } - os.makedirs( - os.path.join(self.model_path, self.model_name), exist_ok=True - ) - - apply_template( - os.path.join(amiciModulePath, "jax.template.py"), - os.path.join(self.model_path, self.model_name, "jax.py"), - tpl_data, - ) - def _generate_c_code(self) -> None: """ Create C++ code files for the model based on @@ -795,7 +673,7 @@ def _get_function_body( lines = [] if len(equations) == 0 or ( - isinstance(equations, (sp.Matrix, sp.ImmutableDenseMatrix)) + isinstance(equations, sp.Matrix | sp.ImmutableDenseMatrix) and min(equations.shape) == 0 ): # dJydy is a list @@ -852,7 +730,7 @@ def _get_function_body( f"reinitialization_state_idxs.cend(), {index}) != " "reinitialization_state_idxs.cend())", f" {function}[{index}] = " - f"{self._code_printer_cpp.doprint(formula)};", + f"{self._code_printer.doprint(formula)};", ] ) cases[ipar] = expressions @@ -867,12 +745,12 @@ def _get_function_body( f"reinitialization_state_idxs.cend(), {index}) != " "reinitialization_state_idxs.cend())\n " f"{function}[{index}] = " - f"{self._code_printer_cpp.doprint(formula)};" + f"{self._code_printer.doprint(formula)};" ) elif function in event_functions: cases = { - ie: self._code_printer_cpp._get_sym_lines_array( + ie: self._code_printer._get_sym_lines_array( equations[ie], function, 0 ) for ie in range(self.model.num_events()) @@ -885,7 +763,7 @@ def _get_function_body( for ie, inner_equations in enumerate(equations): inner_lines = [] inner_cases = { - ipar: self._code_printer_cpp._get_sym_lines_array( + ipar: self._code_printer._get_sym_lines_array( inner_equations[:, ipar], function, 0 ) for ipar in range(self.model.num_par()) @@ -900,7 +778,7 @@ def _get_function_body( and equations.shape[1] == self.model.num_par() ): cases = { - ipar: self._code_printer_cpp._get_sym_lines_array( + ipar: self._code_printer._get_sym_lines_array( equations[:, ipar], function, 0 ) for ipar in range(self.model.num_par()) @@ -910,7 +788,7 @@ def _get_function_body( elif function in multiobs_functions: if function == "dJydy": cases = { - iobs: self._code_printer_cpp._get_sym_lines_array( + iobs: self._code_printer._get_sym_lines_array( equations[iobs], function, 0 ) for iobs in range(self.model.num_obs()) @@ -918,7 +796,7 @@ def _get_function_body( } else: cases = { - iobs: self._code_printer_cpp._get_sym_lines_array( + iobs: self._code_printer._get_sym_lines_array( equations[:, iobs], function, 0 ) for iobs in range(equations.shape[1]) @@ -948,7 +826,7 @@ def _get_function_body( tmp_equations = sp.Matrix( [equations[i] for i in static_idxs] ) - tmp_lines = self._code_printer_cpp._get_sym_lines_symbols( + tmp_lines = self._code_printer._get_sym_lines_symbols( tmp_symbols, tmp_equations, function, @@ -974,7 +852,7 @@ def _get_function_body( [equations[i] for i in dynamic_idxs] ) - tmp_lines = self._code_printer_cpp._get_sym_lines_symbols( + tmp_lines = self._code_printer._get_sym_lines_symbols( tmp_symbols, tmp_equations, function, @@ -986,12 +864,12 @@ def _get_function_body( lines.extend(tmp_lines) else: - lines += self._code_printer_cpp._get_sym_lines_symbols( + lines += self._code_printer._get_sym_lines_symbols( symbols, equations, function, 4 ) else: - lines += self._code_printer_cpp._get_sym_lines_array( + lines += self._code_printer._get_sym_lines_array( equations, function, 4 ) @@ -1136,8 +1014,7 @@ def _write_model_header_cpp(self) -> None: ) ), "NDXDOTDX_EXPLICIT": len(self.model.sparsesym("dxdotdx_explicit")), - "NDJYDY": "std::vector{%s}" - % ",".join(str(len(x)) for x in self.model.sparsesym("dJydy")), + "NDJYDY": f"std::vector{{{','.join(str(len(x)) for x in self.model.sparsesym('dJydy'))}}}", "NDXRDATADXSOLVER": len(self.model.sparsesym("dx_rdatadx_solver")), "NDXRDATADTCL": len(self.model.sparsesym("dx_rdatadtcl")), "NDTOTALCLDXRDATA": len(self.model.sparsesym("dtotal_cldx_rdata")), @@ -1147,10 +1024,10 @@ def _write_model_header_cpp(self) -> None: "NK": self.model.num_const(), "O2MODE": "amici::SecondOrderMode::none", # using code printer ensures proper handling of nan/inf - "PARAMETERS": self._code_printer_cpp.doprint(self.model.val("p"))[ + "PARAMETERS": self._code_printer.doprint(self.model.val("p"))[ 1:-1 ], - "FIXED_PARAMETERS": self._code_printer_cpp.doprint( + "FIXED_PARAMETERS": self._code_printer.doprint( self.model.val("k") )[1:-1], "PARAMETER_NAMES_INITIALIZER_LIST": self._get_symbol_name_initializer_list( @@ -1344,7 +1221,7 @@ def _get_symbol_id_initializer_list(self, name: str) -> str: Template initializer list of ids """ return "\n".join( - f'"{self._code_printer_cpp.doprint(symbol)}", // {name}[{idx}]' + f'"{self._code_printer.doprint(symbol)}", // {name}[{idx}]' for idx, symbol in enumerate(self.model.sym(name)) ) diff --git a/python/sdist/amici/jax.template.py b/python/sdist/amici/jax/jax.template.py similarity index 71% rename from python/sdist/amici/jax.template.py rename to python/sdist/amici/jax/jax.template.py index ddddb8a64b..d395715422 100644 --- a/python/sdist/amici/jax.template.py +++ b/python/sdist/amici/jax/jax.template.py @@ -14,28 +14,28 @@ def __init__(self): super().__init__() def _xdot(self, t, x, args): - pk, tcl = args + p, tcl = args TPL_X_SYMS = x - TPL_PK_SYMS = pk + TPL_P_SYMS = p TPL_TCL_SYMS = tcl - TPL_W_SYMS = self._w(t, x, pk, tcl) + TPL_W_SYMS = self._w(t, x, p, tcl) TPL_XDOT_EQ return TPL_XDOT_RET - def _w(self, t, x, pk, tcl): + def _w(self, t, x, p, tcl): TPL_X_SYMS = x - TPL_PK_SYMS = pk + TPL_P_SYMS = p TPL_TCL_SYMS = tcl TPL_W_EQ return TPL_W_RET - def _x0(self, pk): - TPL_PK_SYMS = pk + def _x0(self, p): + TPL_P_SYMS = p TPL_X0_EQ @@ -56,25 +56,25 @@ def _x_rdata(self, x, tcl): return TPL_X_RDATA_RET - def _tcl(self, x, pk): + def _tcl(self, x, p): TPL_X_RDATA_SYMS = x - TPL_PK_SYMS = pk + TPL_P_SYMS = p TPL_TOTAL_CL_EQ return TPL_TOTAL_CL_RET - def _y(self, t, x, pk, tcl): + def _y(self, t, x, p, tcl): TPL_X_SYMS = x - TPL_PK_SYMS = pk - TPL_W_SYMS = self._w(t, x, pk, tcl) + TPL_P_SYMS = p + TPL_W_SYMS = self._w(t, x, p, tcl) TPL_Y_EQ return TPL_Y_RET - def _sigmay(self, y, pk): - TPL_PK_SYMS = pk + def _sigmay(self, y, p): + TPL_P_SYMS = p TPL_Y_SYMS = y @@ -82,10 +82,10 @@ def _sigmay(self, y, pk): return TPL_SIGMAY_RET - def _nllh(self, t, x, pk, tcl, my, iy): - y = self._y(t, x, pk, tcl) + def _nllh(self, t, x, p, tcl, my, iy): + y = self._y(t, x, p, tcl) TPL_Y_SYMS = y - TPL_SIGMAY_SYMS = self._sigmay(y, pk) + TPL_SIGMAY_SYMS = self._sigmay(y, p) TPL_JY_EQ @@ -101,7 +101,7 @@ def state_ids(self): @property def parameter_ids(self): - return TPL_PK_IDS + return TPL_P_IDS Model = JAXModel_TPL_MODEL_NAME diff --git a/python/sdist/amici/jaxcodeprinter.py b/python/sdist/amici/jax/jaxcodeprinter.py similarity index 100% rename from python/sdist/amici/jaxcodeprinter.py rename to python/sdist/amici/jax/jaxcodeprinter.py diff --git a/python/sdist/amici/jax/ode_export.py b/python/sdist/amici/jax/ode_export.py new file mode 100644 index 0000000000..7ea4a29d8a --- /dev/null +++ b/python/sdist/amici/jax/ode_export.py @@ -0,0 +1,277 @@ +""" +JAX Export +---------- +This module provides all necessary functionality to specify an ordinary +differential equation model and generate executable jax simulation code. +The user generally won't have to directly call any function from this module +as this will be done by +:py:func:`amici.pysb_import.pysb2jax`, +:py:func:`amici.sbml_import.SbmlImporter.sbml2jax` and +:py:func:`amici.petab_import.import_model`. +""" + +from __future__ import annotations +import logging +import os +from pathlib import Path + +import sympy as sp + +from amici import ( + amiciModulePath, +) + +from amici._codegen.template import apply_template +from amici.jax.jaxcodeprinter import AmiciJaxCodePrinter +from amici.jax.model import JAXModel +from amici.de_model import DEModel +from amici.de_export import is_valid_identifier +from amici.import_utils import ( + strip_pysb, +) +from amici.logging import get_logger, log_execution_time, set_log_level +from amici.sympy_utils import ( + _custom_pow_eval_derivative, + _monkeypatched, +) + +#: python log manager +logger = get_logger(__name__, logging.ERROR) + + +def _jax_variable_assignments( + model: DEModel, sym_names: tuple[str, ...] +) -> dict: + return { + f"{sym_name.upper()}_SYMS": "".join( + str(strip_pysb(s)) + ", " for s in model.sym(sym_name) + ) + if model.sym(sym_name) + else "_" + for sym_name in sym_names + } + + +def _jax_variable_equations( + model: DEModel, + code_printer: AmiciJaxCodePrinter, + eq_names: tuple[str, ...], + subs: dict, + indent: int = 8, +) -> dict: + return { + f"{eq_name.upper()}_EQ": "\n".join( + code_printer._get_sym_lines( + (str(strip_pysb(s)) for s in model.sym(eq_name)), + model.eq(eq_name).subs(subs), + indent, + ) + )[indent:] # remove indent for first line + for eq_name in eq_names + } + + +def _jax_return_variables( + model: DEModel, + eq_names: tuple[str, ...], +) -> dict: + return { + f"{eq_name.upper()}_RET": _jnp_array_str( + strip_pysb(s) for s in model.sym(eq_name) + ) + if model.sym(eq_name) + else "jnp.array([])" + for eq_name in eq_names + } + + +def _jax_variable_ids(model: DEModel, sym_names: tuple[str, ...]) -> dict: + return { + f"{sym_name.upper()}_IDS": "".join( + f'"{strip_pysb(s)}", ' for s in model.sym(sym_name) + ) + if model.sym(sym_name) + else "tuple()" + for sym_name in sym_names + } + + +def _jnp_array_str(array) -> str: + elems = ", ".join(str(s) for s in array) + + return f"jnp.array([{elems}])" + + +class ODEExporter: + """ + The ODEExporter class generates AMICI jax files for a model as + defined in symbolic expressions. + + :ivar model: + DE definition + + :ivar verbose: + more verbose output if True + + :ivar model_name: + name of the model that will be used for compilation + + :ivar model_path: + path to the generated model specific files + + :ivar _code_printer: + Code printer to generate JAX code + """ + + def __init__( + self, + ode_model: DEModel, + outdir: Path | str | None = None, + verbose: bool | int | None = False, + model_name: str | None = "model", + ): + """ + Generate AMICI jax files for the ODE provided to the constructor. + + :param ode_model: + DE model definition + + :param outdir: + see :meth:`amici.de_export.DEExporter.set_paths` + + :param verbose: + verbosity level for logging, ``True``/``False`` default to + :data:`logging.Error`/:data:`logging.DEBUG` + + :param model_name: + name of the model to be used during code generation + """ + set_log_level(logger, verbose) + + self.verbose: bool = logger.getEffectiveLevel() <= logging.DEBUG + + self.model_path: Path = Path() + + self.set_name(model_name) + self.set_paths(outdir) + + self.model: DEModel = ode_model + + self._code_printer = AmiciJaxCodePrinter() + + @log_execution_time("generating jax code", logger) + def generate_model_code(self) -> None: + """ + Generates the jax code for the loaded model + """ + with _monkeypatched( + sp.Pow, "_eval_derivative", _custom_pow_eval_derivative + ): + self._prepare_model_folder() + self._generate_jax_code() + + def _prepare_model_folder(self) -> None: + """ + Create model directory or remove all files if the output directory + already exists. + """ + self.model_path.mkdir(parents=True, exist_ok=True) + + for file in self.model_path.glob("*"): + if file.is_file(): + file.unlink() + + @log_execution_time("generating jax code", logger) + def _generate_jax_code(self) -> None: + eq_names = ( + "xdot", + "w", + "x0", + "y", + "sigmay", + "Jy", + "x_solver", + "x_rdata", + "total_cl", + ) + sym_names = ("p", "x", "tcl", "w", "my", "y", "sigmay", "x_rdata") + + indent = 8 + + # replaces Heaviside variables with corresponding functions + subs_heaviside = dict( + zip( + self.model.sym("h"), + [sp.Heaviside(x) for x in self.model.eq("root")], + strict=True, + ) + ) + # replaces observables with a generic my variable + subs_observables = dict( + zip( + self.model.sym("my"), + [sp.Symbol("my")] * len(self.model.sym("my")), + strict=True, + ) + ) + subs = subs_heaviside | subs_observables + + tpl_data = { + # assign named variable using corresponding algebraic formula (function body) + **_jax_variable_equations( + self.model, self._code_printer, eq_names, subs, indent + ), + # create jax array from concatenation of named variables + **_jax_return_variables(self.model, eq_names), + # assign named variables from a jax array + **_jax_variable_assignments(self.model, sym_names), + # tuple of variable names (ids as they are unique) + **_jax_variable_ids(self.model, ("p", "k", "y", "x")), + **{ + "MODEL_NAME": self.model_name, + # keep track of the API version that the model was generated with so we + # can flag conflicts in the future + "MODEL_API_VERSION": f"'{JAXModel.MODEL_API_VERSION}'", + }, + } + outdir = self.model_path / (self.model_name + "_jax") + outdir.mkdir(parents=True, exist_ok=True) + + apply_template( + Path(amiciModulePath) / "jax" / "jax.template.py", + outdir / "__init__.py", + tpl_data, + ) + + def set_paths(self, output_dir: str | Path | None = None) -> None: + """ + Set output paths for the model and create if necessary + + :param output_dir: + relative or absolute path where the generated model + code is to be placed. If ``None``, this will default to + ``amici-{self.model_name}`` in the current working directory. + will be created if it does not exist. + + """ + if output_dir is None: + output_dir = Path(os.getcwd()) / f"amici-{self.model_name}" + + self.model_path = Path(output_dir).resolve() + + def set_name(self, model_name: str) -> None: + """ + Sets the model name + + :param model_name: + name of the model (may only contain upper and lower case letters, + digits and underscores, and must not start with a digit) + """ + if not is_valid_identifier(model_name): + raise ValueError( + f"'{model_name}' is not a valid model name. " + "Model name may only contain upper and lower case letters, " + "digits and underscores, and must not start with a digit." + ) + + self.model_name = model_name diff --git a/python/sdist/amici/petab/import_helpers.py b/python/sdist/amici/petab/import_helpers.py index 19afe5b237..daa902efb0 100644 --- a/python/sdist/amici/petab/import_helpers.py +++ b/python/sdist/amici/petab/import_helpers.py @@ -131,18 +131,26 @@ def _create_model_name(folder: str | Path) -> str: return os.path.split(os.path.normpath(folder))[-1] -def _can_import_model(model_name: str, model_output_dir: str | Path) -> bool: +def _can_import_model( + model_name: str, model_output_dir: str | Path, jax: bool = False +) -> bool: """ Check whether a module of that name can already be imported. """ # try to import (in particular checks version) + suffix = "_jax" if jax else "" try: - model_module = amici.import_model_module(model_name, model_output_dir) + model_module = amici.import_model_module( + model_name + suffix, model_output_dir + ) except ModuleNotFoundError: return False # no need to (re-)compile - return hasattr(model_module, "getModel") + if jax: + return hasattr(model_module, "Model") + else: + return hasattr(model_module, "getModel") def get_fixed_parameters( diff --git a/python/sdist/amici/petab/petab_import.py b/python/sdist/amici/petab/petab_import.py index 87ec3fbfec..63bade9bb8 100644 --- a/python/sdist/amici/petab/petab_import.py +++ b/python/sdist/amici/petab/petab_import.py @@ -126,7 +126,7 @@ def import_petab_problem( # check if compilation necessary if compile_ or ( compile_ is None - and not _can_import_model(model_name, model_output_dir) + and not _can_import_model(model_name, model_output_dir, jax) ): # check if folder exists if os.listdir(model_output_dir) and not compile_: @@ -146,7 +146,7 @@ def import_petab_problem( petab_problem, model_name=model_name, model_output_dir=model_output_dir, - compile=kwargs.pop("compile", not jax), + jax=jax, **kwargs, ) else: @@ -155,19 +155,18 @@ def import_petab_problem( model_name=model_name, model_output_dir=model_output_dir, non_estimated_parameters_as_constants=non_estimated_parameters_as_constants, - compile=kwargs.pop("compile", not jax), + jax=jax, **kwargs, ) # import model - if not jax: - model_module = amici.import_model_module(model_name, model_output_dir) + suffix = "_jax" if jax else "" + model_module = amici.import_model_module( + model_name + suffix, model_output_dir + ) - else: - jax_model_module = amici._module_from_path( - "jax", Path(model_output_dir) / model_name / "jax.py" - ) - model = jax_model_module.Model() + if jax: + model = model_module.Model() logger.info( f"Successfully loaded jax model {model_name} " diff --git a/python/sdist/amici/petab/pysb_import.py b/python/sdist/amici/petab/pysb_import.py index aac3a8f330..32de3d6666 100644 --- a/python/sdist/amici/petab/pysb_import.py +++ b/python/sdist/amici/petab/pysb_import.py @@ -168,6 +168,7 @@ def import_model_pysb( model_output_dir: str | Path | None = None, verbose: bool | int | None = True, model_name: str | None = None, + jax: bool = False, **kwargs, ) -> None: """ @@ -186,6 +187,9 @@ def import_model_pysb( :param model_name: Name of the generated model module + :param jax: + Whether to generate JAX code instead of C++ code. + :param kwargs: Additional keyword arguments to be passed to :func:`amici.pysb_import.pysb2amici`. @@ -259,16 +263,31 @@ def import_model_pysb( petab_problem.observable_df ) - from amici.pysb_import import pysb2amici - - pysb2amici( - model=pysb_model, - output_dir=model_output_dir, - model_name=model_name, - verbose=True, - observables=observables, - sigmas=sigmas, - constant_parameters=constant_parameters, - noise_distributions=noise_distrs, - **kwargs, - ) + if jax: + from amici.pysb_import import pysb2jax + + pysb2jax( + model=pysb_model, + output_dir=model_output_dir, + model_name=model_name, + verbose=True, + observables=observables, + sigmas=sigmas, + noise_distributions=noise_distrs, + **kwargs, + ) + return + else: + from amici.pysb_import import pysb2amici + + pysb2amici( + model=pysb_model, + output_dir=model_output_dir, + model_name=model_name, + verbose=True, + observables=observables, + sigmas=sigmas, + constant_parameters=constant_parameters, + noise_distributions=noise_distrs, + **kwargs, + ) diff --git a/python/sdist/amici/petab/sbml_import.py b/python/sdist/amici/petab/sbml_import.py index 92009bf7cd..02a2c4e12c 100644 --- a/python/sdist/amici/petab/sbml_import.py +++ b/python/sdist/amici/petab/sbml_import.py @@ -26,6 +26,173 @@ logger = logging.getLogger(__name__) +def _workaround_initial_states( + petab_problem: petab.Problem, sbml_model: libsbml.Model, **kwargs +): + # TODO: to parameterize initial states or compartment sizes, we currently + # need initial assignments. if they occur in the condition table, we + # create a new parameter initial_${speciesOrCompartmentID}. + # feels dirty and should be changed (see also #924) + # + + # state variable IDs and initial values specified via the conditions' table + initial_states = get_states_in_condition_table(petab_problem) + # is there any condition that involves preequilibration? + requires_preequilibration = ( + petab_problem.measurement_df is not None + and petab.PREEQUILIBRATION_CONDITION_ID in petab_problem.measurement_df + and petab_problem.measurement_df[petab.PREEQUILIBRATION_CONDITION_ID] + .notnull() + .any() + ) + estimated_parameters_ids = petab_problem.get_x_ids(free=True, fixed=False) + # any initial states overridden to be estimated via the conditions table? + has_estimated_initial_states = any( + par_id in petab_problem.condition_df[initial_states.keys()].values + for par_id in estimated_parameters_ids + ) + + if ( + has_estimated_initial_states + and requires_preequilibration + and kwargs.setdefault("generate_sensitivity_code", True) + ): + # To support reinitialization of initial conditions after + # preequilibration we need fixed parameters for the initial + # conditions. If we need sensitivities w.r.t. to initial conditions, + # we need to create non-fixed parameters for the initial conditions. + # We can't have both for the same state variable. + # (We could handle it via separate amici models if pre-equilibration + # and estimation of initial values for a given state variable are + # used in separate PEtab conditions.) + # We currently assume that we do need sensitivities w.r.t. initial + # conditions if sensitivities are needed at all. + # TODO: check this state by state, then we can support some additional + # cases + raise NotImplementedError( + "PEtab problems that have both, estimated initial conditions " + "specified in the condition table, and preequilibration with " + "initial conditions specified in the condition table are not " + "supported." + ) + + fixed_parameters = [] + if initial_states and requires_preequilibration: + # add preequilibration indicator variable + if sbml_model.getParameter(PREEQ_INDICATOR_ID) is not None: + raise AssertionError( + "Model already has a parameter with ID " + f"{PREEQ_INDICATOR_ID}. Cannot handle " + "species and compartments in condition table " + "then." + ) + indicator = sbml_model.createParameter() + indicator.setId(PREEQ_INDICATOR_ID) + indicator.setName(PREEQ_INDICATOR_ID) + # Can only reset parameters after preequilibration if they are fixed. + fixed_parameters.append(PREEQ_INDICATOR_ID) + logger.debug( + "Adding preequilibration indicator " + f"constant {PREEQ_INDICATOR_ID}" + ) + logger.debug( + f"Adding initial assignments for {list(initial_states.keys())}" + ) + for assignee_id in initial_states: + init_par_id_preeq = f"initial_{assignee_id}_preeq" + init_par_id_sim = f"initial_{assignee_id}_sim" + for init_par_id in ( + [init_par_id_preeq] if requires_preequilibration else [] + ) + [init_par_id_sim]: + if sbml_model.getElementBySId(init_par_id) is not None: + raise ValueError( + "Cannot create parameter for initial assignment " + f"for {assignee_id} because an entity named " + f"{init_par_id} exists already in the model." + ) + init_par = sbml_model.createParameter() + init_par.setId(init_par_id) + init_par.setName(init_par_id) + if requires_preequilibration: + # must be a fixed parameter to allow reinitialization + # TODO: also add other initial condition parameters that are + # not estimated + fixed_parameters.append(init_par_id) + + assignment = sbml_model.getInitialAssignment(assignee_id) + if assignment is None: + assignment = sbml_model.createInitialAssignment() + assignment.setSymbol(assignee_id) + else: + logger.debug( + "The SBML model has an initial assignment defined " + f"for model entity {assignee_id}, but this entity " + "also has an initial value defined in the PEtab " + "condition table. The SBML initial assignment will " + "be overwritten to handle preequilibration and " + "initial values specified by the PEtab problem." + ) + if requires_preequilibration: + formula = ( + f"{PREEQ_INDICATOR_ID} * {init_par_id_preeq} " + f"+ (1 - {PREEQ_INDICATOR_ID}) * {init_par_id_sim}" + ) + else: + formula = init_par_id_sim + math_ast = libsbml.parseL3Formula(formula) + assignment.setMath(math_ast) + # + + return fixed_parameters + + +def _workaround_observable_parameters( + observables, sigmas, sbml_model, output_parameter_defaults +): + # TODO: adding extra output parameters is currently not supported, + # so we add any output parameters to the SBML model. + # this should be changed to something more elegant + # + formulas = chain( + (val["formula"] for val in observables.values()), sigmas.values() + ) + output_parameters = OrderedDict() + for formula in formulas: + # we want reproducible parameter ordering upon repeated import + free_syms = sorted( + sp.sympify(formula, locals=_clash).free_symbols, + key=lambda symbol: symbol.name, + ) + for free_sym in free_syms: + sym = str(free_sym) + if ( + sbml_model.getElementBySId(sym) is None + and sym != "time" + and sym not in observables + ): + output_parameters[sym] = None + logger.debug( + "Adding output parameters to model: " + f"{list(output_parameters.keys())}" + ) + output_parameter_defaults = output_parameter_defaults or {} + if extra_pars := ( + set(output_parameter_defaults) - set(output_parameters.keys()) + ): + raise ValueError( + f"Default output parameter values were given for {extra_pars}, " + "but they those are not output parameters." + ) + + for par in output_parameters.keys(): + _add_global_parameter( + sbml_model=sbml_model, + parameter_id=par, + value=output_parameter_defaults.get(par, 0.0), + ) + # + + @log_execution_time("Importing PEtab model", logger) def import_model_sbml( sbml_model: Union[str, Path, "libsbml.Model"] = None, @@ -38,6 +205,7 @@ def import_model_sbml( non_estimated_parameters_as_constants=True, output_parameter_defaults: dict[str, float] | None = None, discard_sbml_annotations: bool = False, + jax: bool = False, **kwargs, ) -> amici.SbmlImporter: """ @@ -83,6 +251,9 @@ def import_model_sbml( :param discard_sbml_annotations: Discard information contained in AMICI SBML annotations (debug). + :param jax: + Whether to generate JAX code instead of C++ code. + :param kwargs: Additional keyword arguments to be passed to :meth:`amici.sbml_import.SbmlImporter.sbml2amici`. @@ -111,7 +282,7 @@ def import_model_sbml( # Model name from SBML ID or filename if model_name is None: if not (model_name := petab_problem.model.sbml_model.getId()): - if not isinstance(sbml_model, (str, Path)): + if not isinstance(sbml_model, str | Path): raise ValueError( "No `model_name` was provided and no model " "ID was specified in the SBML model." @@ -174,162 +345,14 @@ def import_model_sbml( f"({len(sigmas)}) do not match." ) - # TODO: adding extra output parameters is currently not supported, - # so we add any output parameters to the SBML model. - # this should be changed to something more elegant - # - formulas = chain( - (val["formula"] for val in observables.values()), sigmas.values() - ) - output_parameters = OrderedDict() - for formula in formulas: - # we want reproducible parameter ordering upon repeated import - free_syms = sorted( - sp.sympify(formula, locals=_clash).free_symbols, - key=lambda symbol: symbol.name, - ) - for free_sym in free_syms: - sym = str(free_sym) - if ( - sbml_model.getElementBySId(sym) is None - and sym != "time" - and sym not in observables - ): - output_parameters[sym] = None - logger.debug( - "Adding output parameters to model: " - f"{list(output_parameters.keys())}" - ) - output_parameter_defaults = output_parameter_defaults or {} - if extra_pars := ( - set(output_parameter_defaults) - set(output_parameters.keys()) - ): - raise ValueError( - f"Default output parameter values were given for {extra_pars}, " - "but they those are not output parameters." - ) - - for par in output_parameters.keys(): - _add_global_parameter( - sbml_model=sbml_model, - parameter_id=par, - value=output_parameter_defaults.get(par, 0.0), - ) - # - - # TODO: to parameterize initial states or compartment sizes, we currently - # need initial assignments. if they occur in the condition table, we - # create a new parameter initial_${speciesOrCompartmentID}. - # feels dirty and should be changed (see also #924) - # - - # state variable IDs and initial values specified via the conditions' table - initial_states = get_states_in_condition_table(petab_problem) - # is there any condition that involves preequilibration? - requires_preequilibration = ( - petab_problem.measurement_df is not None - and petab.PREEQUILIBRATION_CONDITION_ID in petab_problem.measurement_df - and petab_problem.measurement_df[petab.PREEQUILIBRATION_CONDITION_ID] - .notnull() - .any() - ) - estimated_parameters_ids = petab_problem.get_x_ids(free=True, fixed=False) - # any initial states overridden to be estimated via the conditions table? - has_estimated_initial_states = any( - par_id in petab_problem.condition_df[initial_states.keys()].values - for par_id in estimated_parameters_ids + _workaround_observable_parameters( + observables, sigmas, sbml_model, output_parameter_defaults ) - - if ( - has_estimated_initial_states - and requires_preequilibration - and kwargs.setdefault("generate_sensitivity_code", True) - ): - # To support reinitialization of initial conditions after - # preequilibration we need fixed parameters for the initial - # conditions. If we need sensitivities w.r.t. to initial conditions, - # we need to create non-fixed parameters for the initial conditions. - # We can't have both for the same state variable. - # (We could handle it via separate amici models if pre-equilibration - # and estimation of initial values for a given state variable are - # used in separate PEtab conditions.) - # We currently assume that we do need sensitivities w.r.t. initial - # conditions if sensitivities are needed at all. - # TODO: check this state by state, then we can support some additional - # cases - raise NotImplementedError( - "PEtab problems that have both, estimated initial conditions " - "specified in the condition table, and preequilibration with " - "initial conditions specified in the condition table are not " - "supported." - ) - - fixed_parameters = [] - if initial_states and requires_preequilibration: - # add preequilibration indicator variable - if sbml_model.getParameter(PREEQ_INDICATOR_ID) is not None: - raise AssertionError( - "Model already has a parameter with ID " - f"{PREEQ_INDICATOR_ID}. Cannot handle " - "species and compartments in condition table " - "then." - ) - indicator = sbml_model.createParameter() - indicator.setId(PREEQ_INDICATOR_ID) - indicator.setName(PREEQ_INDICATOR_ID) - # Can only reset parameters after preequilibration if they are fixed. - fixed_parameters.append(PREEQ_INDICATOR_ID) - logger.debug( - "Adding preequilibration indicator " - f"constant {PREEQ_INDICATOR_ID}" - ) - logger.debug( - f"Adding initial assignments for {list(initial_states.keys())}" + fixed_parameters = _workaround_initial_states( + petab_problem=petab_problem, + sbml_model=sbml_model, + **kwargs, ) - for assignee_id in initial_states: - init_par_id_preeq = f"initial_{assignee_id}_preeq" - init_par_id_sim = f"initial_{assignee_id}_sim" - for init_par_id in ( - [init_par_id_preeq] if requires_preequilibration else [] - ) + [init_par_id_sim]: - if sbml_model.getElementBySId(init_par_id) is not None: - raise ValueError( - "Cannot create parameter for initial assignment " - f"for {assignee_id} because an entity named " - f"{init_par_id} exists already in the model." - ) - init_par = sbml_model.createParameter() - init_par.setId(init_par_id) - init_par.setName(init_par_id) - if requires_preequilibration: - # must be a fixed parameter to allow reinitialization - # TODO: also add other initial condition parameters that are - # not estimated - fixed_parameters.append(init_par_id) - - assignment = sbml_model.getInitialAssignment(assignee_id) - if assignment is None: - assignment = sbml_model.createInitialAssignment() - assignment.setSymbol(assignee_id) - else: - logger.debug( - "The SBML model has an initial assignment defined " - f"for model entity {assignee_id}, but this entity " - "also has an initial value defined in the PEtab " - "condition table. The SBML initial assignment will " - "be overwritten to handle preequilibration and " - "initial values specified by the PEtab problem." - ) - if requires_preequilibration: - formula = ( - f"{PREEQ_INDICATOR_ID} * {init_par_id_preeq} " - f"+ (1 - {PREEQ_INDICATOR_ID}) * {init_par_id_sim}" - ) - else: - formula = init_par_id_sim - math_ast = libsbml.parseL3Formula(formula) - assignment.setMath(math_ast) - # fixed_parameters.extend( _get_fixed_parameters_sbml( @@ -346,17 +369,29 @@ def import_model_sbml( ) # Create Python module from SBML model - sbml_importer.sbml2amici( - model_name=model_name, - output_dir=model_output_dir, - observables=observables, - constant_parameters=fixed_parameters, - sigmas=sigmas, - allow_reinit_fixpar_initcond=allow_reinit_fixpar_initcond, - noise_distributions=noise_distrs, - verbose=verbose, - **kwargs, - ) + if jax: + sbml_importer.sbml2jax( + model_name=model_name, + output_dir=model_output_dir, + observables=observables, + sigmas=sigmas, + noise_distributions=noise_distrs, + verbose=verbose, + **kwargs, + ) + return sbml_importer + else: + sbml_importer.sbml2amici( + model_name=model_name, + output_dir=model_output_dir, + observables=observables, + constant_parameters=fixed_parameters, + sigmas=sigmas, + allow_reinit_fixpar_initcond=allow_reinit_fixpar_initcond, + noise_distributions=noise_distrs, + verbose=verbose, + **kwargs, + ) if kwargs.get( "compile", diff --git a/python/sdist/amici/pysb_import.py b/python/sdist/amici/pysb_import.py index a273759536..b84fadea44 100644 --- a/python/sdist/amici/pysb_import.py +++ b/python/sdist/amici/pysb_import.py @@ -12,7 +12,6 @@ from pathlib import Path from typing import ( Any, - Union, ) from collections.abc import Callable from collections.abc import Iterable @@ -45,11 +44,112 @@ from .logging import get_logger, log_execution_time, set_log_level CL_Prototype = dict[str, dict[str, Any]] -ConservationLaw = dict[str, Union[dict, str, sp.Basic]] +ConservationLaw = dict[str, dict | str | sp.Basic] logger = get_logger(__name__, logging.ERROR) +def pysb2jax( + model: pysb.Model, + output_dir: str | Path | None = None, + observables: list[str] = None, + sigmas: dict[str, str] = None, + noise_distributions: dict[str, str | Callable] | None = None, + verbose: int | bool = False, + compute_conservation_laws: bool = True, + simplify: Callable = _default_simplify, + # Do not enable by default without testing. + # See https://github.com/AMICI-dev/AMICI/pull/1672 + cache_simplify: bool = False, + model_name: str | None = None, +): + r""" + Generate AMICI jax files for the provided model. + + .. warning:: + **PySB models with Compartments** + + When importing a PySB model with ``pysb.Compartment``\ s, BioNetGen + scales reaction fluxes with the compartment size. Instead of using the + respective symbols, the compartment size Parameter or Expression is + evaluated when generating equations. This may lead to unexpected + results if the compartment size parameter is changed for AMICI + simulations. + + :param model: + pysb model, :attr:`pysb.Model.name` will determine the name of the + generated module + + :param output_dir: + see :meth:`amici.de_export.ODEExporter.set_paths` + + :param observables: + list of :class:`pysb.core.Expression` or :class:`pysb.core.Observable` + names in the provided model that should be mapped to observables + + :param sigmas: + dict of :class:`pysb.core.Expression` names that should be mapped to + sigmas + + :param noise_distributions: + dict with names of observable Expressions as keys and a noise type + identifier, or a callable generating a custom noise formula string + (see :py:func:`amici.import_utils.noise_distribution_to_cost_function` + ). If nothing is passed for some observable id, a normal model is + assumed as default. + + :param verbose: verbosity level for logging, True/False default to + :attr:`logging.DEBUG`/:attr:`logging.ERROR` + + :param compute_conservation_laws: + if set to ``True``, conservation laws are automatically computed and + applied such that the state-jacobian of the ODE right-hand-side has + full rank. This option should be set to ``True`` when using the Newton + algorithm to compute steadystates + + :param simplify: + see :attr:`amici.DEModel._simplify` + + :param cache_simplify: + see :func:`amici.DEModel.__init__` + Note that there are possible issues with PySB models: + https://github.com/AMICI-dev/AMICI/pull/1672 + + :param model_name: + Name for the generated model module. If None, :attr:`pysb.Model.name` + will be used. + """ + if observables is None: + observables = [] + + if sigmas is None: + sigmas = {} + + model_name = model_name or model.name + + set_log_level(logger, verbose) + ode_model = ode_model_from_pysb_importer( + model, + observables=observables, + sigmas=sigmas, + noise_distributions=noise_distributions, + compute_conservation_laws=compute_conservation_laws, + simplify=simplify, + cache_simplify=cache_simplify, + verbose=verbose, + ) + + from amici.jax.ode_export import ODEExporter + + exporter = ODEExporter( + ode_model, + outdir=output_dir, + model_name=model_name, + verbose=verbose, + ) + exporter.generate_model_code() + + def pysb2amici( model: pysb.Model, output_dir: str | Path | None = None, @@ -180,7 +280,7 @@ def pysb2amici( # Sympy code optimizations are incompatible with PySB objects, as # `pysb.Observable` comes with its own `.match` which overrides # `sympy.Basic.match()`, breaking `sympy.codegen.rewriting.optimize`. - exporter._code_printer_cpp._fpoptimizer = None + exporter._code_printer._fpoptimizer = None exporter.generate_model_code() if compile: diff --git a/python/sdist/amici/sbml_import.py b/python/sdist/amici/sbml_import.py index fcaa1ed752..557ad02d0f 100644 --- a/python/sdist/amici/sbml_import.py +++ b/python/sdist/amici/sbml_import.py @@ -16,7 +16,6 @@ from pathlib import Path from typing import ( Any, - Union, ) from collections.abc import Callable from collections.abc import Iterable, Sequence @@ -63,7 +62,7 @@ default_symbols = {symbol: {} for symbol in SymbolId} -ConservationLaw = dict[str, Union[str, sp.Expr]] +ConservationLaw = dict[str, str | sp.Expr] logger = get_logger(__name__, logging.ERROR) @@ -447,6 +446,110 @@ def sbml2amici( ) exporter.compile_model() + def sbml2jax( + self, + model_name: str, + output_dir: str | Path = None, + observables: dict[str, dict[str, str]] = None, + sigmas: dict[str, str | float] = None, + noise_distributions: dict[str, str | Callable] = None, + verbose: int | bool = logging.ERROR, + compute_conservation_laws: bool = True, + simplify: Callable | None = _default_simplify, + cache_simplify: bool = False, + log_as_log10: bool = True, + ) -> None: + """ + Generate and compile AMICI jax files for the model provided to the + constructor. + + The resulting model can be imported as a regular Python module. + + Note that this generates model ODEs for changes in concentrations, not + amounts unless the `hasOnlySubstanceUnits` attribute has been + defined for a particular species. + + :param model_name: + Name of the generated model package. + Note that in a given Python session, only one model with a given + name can be loaded at a time. + The generated Python extensions cannot be unloaded. Therefore, + make sure to choose a unique name for each model. + + :param output_dir: + Directory where the generated model package will be stored. + + :param observables: + Observables to be added to the model: + ``dictionary( observableId:{'name':observableName + (optional), 'formula':formulaString)})``. + + :param sigmas: + dictionary(observableId: sigma value or (existing) parameter name) + + :param noise_distributions: + dictionary(observableId: noise type). + If nothing is passed for some observable id, a normal model is + assumed as default. Either pass a noise type identifier, or a + callable generating a custom noise string. + For noise identifiers, see + :func:`amici.import_utils.noise_distribution_to_cost_function`. + + :param verbose: + verbosity level for logging, ``True``/``False`` default to + ``logging.Error``/``logging.DEBUG`` + + :param compute_conservation_laws: + if set to ``True``, conservation laws are automatically computed + and applied such that the state-jacobian of the ODE + right-hand-side has full rank. This option should be set to + ``True`` when using the Newton algorithm to compute steadystate + sensitivities. + Conservation laws for constant species are enabled by default. + Support for conservation laws for non-constant species is + experimental and may be enabled by setting an environment variable + ``AMICI_EXPERIMENTAL_SBML_NONCONST_CLS`` to either ``demartino`` + to use the algorithm proposed by De Martino et al. (2014) + https://doi.org/10.1371/journal.pone.0100750, or to any other value + to use the deterministic algorithm implemented in + ``conserved_moieties2.py``. In some cases, the ``demartino`` may + run for a very long time. This has been observed for example in the + case of stoichiometric coefficients with many significant digits. + + :param simplify: + see :attr:`amici.ODEModel._simplify` + + :param cache_simplify: + see :meth:`amici.ODEModel.__init__` + + :param log_as_log10: + If ``True``, log in the SBML model will be parsed as ``log10`` + (default), if ``False``, log will be parsed as natural logarithm + ``ln``. + """ + set_log_level(logger, verbose) + + ode_model = self._build_ode_model( + observables=observables, + sigmas=sigmas, + noise_distributions=noise_distributions, + verbose=verbose, + compute_conservation_laws=compute_conservation_laws, + simplify=simplify, + cache_simplify=cache_simplify, + log_as_log10=log_as_log10, + ) + + from amici.jax.ode_export import ODEExporter + + exporter = ODEExporter( + ode_model, + model_name=model_name, + outdir=output_dir, + verbose=verbose, + ) + exporter.generate_model_code() + def _build_ode_model( self, observables: dict[str, dict[str, str]] = None, @@ -719,7 +822,7 @@ def check_support(self) -> None: rule.isRate() and not isinstance( self.sbml.getElementBySId(rule.getVariable()), - (sbml.Compartment, sbml.Species, sbml.Parameter), + sbml.Compartment | sbml.Species | sbml.Parameter, ) for rule in self.sbml.getListOfRules() ): @@ -1143,8 +1246,8 @@ def _process_parameters( for parameter in constant_parameters: if not self.sbml.getParameter(parameter): raise KeyError( - "Cannot make %s a constant parameter: " - "Parameter does not exist." % parameter + f"Cannot make {parameter} a constant parameter: " + "Parameter does not exist." ) # parameter ID => initial assignment sympy expression @@ -2880,16 +2983,14 @@ def _parse_event_trigger(trigger: sp.Expr) -> sp.Expr: # convert relational expressions into trigger functions if isinstance( trigger, - (sp.core.relational.LessThan, sp.core.relational.StrictLessThan), + sp.core.relational.LessThan | sp.core.relational.StrictLessThan, ): # y < x or y <= x return -root if isinstance( trigger, - ( - sp.core.relational.GreaterThan, - sp.core.relational.StrictGreaterThan, - ), + sp.core.relational.GreaterThan + | sp.core.relational.StrictGreaterThan, ): # y >= x or y > x return root diff --git a/python/tests/test_jax.py b/python/tests/test_jax.py index 30e205ca26..8f4c68510b 100644 --- a/python/tests/test_jax.py +++ b/python/tests/test_jax.py @@ -12,7 +12,7 @@ import numpy as np from beartype import beartype -from amici.pysb_import import pysb2amici +from amici.pysb_import import pysb2amici, pysb2jax from amici.testing import TemporaryDirectoryWinSafe, skip_on_valgrind from amici.petab.petab_import import import_petab_problem from amici.jax import JAXProblem @@ -39,17 +39,21 @@ def test_conversion(): pysb.Rule("conv", a(s="a") >> a(s="b"), pysb.Parameter("kcat", 0.05)) pysb.Observable("ab", a(s="b")) - with TemporaryDirectoryWinSafe(prefix=model.name) as outdir: + with TemporaryDirectoryWinSafe() as outdir: pysb2amici(model, outdir, verbose=True, observables=["ab"]) + pysb2jax(model, outdir, verbose=True, observables=["ab"]) - model_module = amici.import_model_module( + amici_module = amici.import_model_module( module_name=model.name, module_path=outdir ) + jax_module = amici.import_model_module( + module_name=model.name + "_jax", module_path=outdir + ) ts = tuple(np.linspace(0, 1, 10)) p = jnp.stack((1.0, 0.1), axis=-1) k = tuple() - _test_model(model_module, ts, p, k) + _test_model(amici_module, jax_module, ts, p, k) @skip_on_valgrind @@ -86,7 +90,7 @@ def test_dimerization(): pysb.Observable("a_obs", a()) pysb.Observable("b_obs", b()) - with TemporaryDirectoryWinSafe(prefix=model.name) as outdir: + with TemporaryDirectoryWinSafe() as outdir: pysb2amici( model, outdir, @@ -94,26 +98,34 @@ def test_dimerization(): observables=["a_obs", "b_obs"], constant_parameters=["ksyn_a", "ksyn_b"], ) + pysb2jax( + model, + outdir, + observables=["a_obs", "b_obs"], + ) - model_module = amici.import_model_module( + amici_module = amici.import_model_module( module_name=model.name, module_path=outdir ) + jax_module = amici.import_model_module( + module_name=model.name + "_jax", module_path=outdir + ) ts = tuple(np.linspace(0, 1, 10)) p = jnp.stack((5, 0.5, 0.5, 0.5), axis=-1) k = (0.5, 5) - _test_model(model_module, ts, p, k) + _test_model(amici_module, jax_module, ts, p, k) -def _test_model(model_module, ts, p, k): - amici_model = model_module.getModel() +def _test_model(amici_module, jax_module, ts, p, k): + amici_model = amici_module.getModel() amici_model.setTimepoints(np.asarray(ts, dtype=np.float64)) sol_amici_ref = amici.runAmiciSimulation( amici_model, amici_model.getSolver() ) - jax_model = model_module.get_jax_model() + jax_model = jax_module.Model() amici_model.setParameters(np.asarray(p, dtype=np.float64)) amici_model.setFixedParameters(np.asarray(k, dtype=np.float64)) @@ -129,12 +141,19 @@ def _test_model(model_module, ts, p, k): rs_amici = amici.runAmiciSimulations(amici_model, amici_solver, [edata]) check_fields_jax( - rs_amici, jax_model, edata, ["x", "y", "llh", "res", "x0"] + rs_amici, + jax_model, + amici_model.getParameterIds(), + amici_model.getFixedParameterIds(), + edata, + ["x", "y", "llh", "res", "x0"], ) check_fields_jax( rs_amici, jax_model, + amici_model.getParameterIds(), + amici_model.getFixedParameterIds(), edata, ["sllh", "sx0", "sx", "sres", "sy"], sensi_order=amici.SensitivityOrder.first, @@ -144,6 +163,8 @@ def _test_model(model_module, ts, p, k): def check_fields_jax( rs_amici, jax_model, + parameter_ids, + fixed_parameter_ids, edata, fields, sensi_order=amici.SensitivityOrder.none, @@ -160,7 +181,13 @@ def check_fields_jax( ts_preeq = ts[ts == 0] ts_dyn = ts[ts > 0] ts_posteq = np.array([]) - p = jnp.array(list(edata.parameters) + list(edata.fixedParameters)) + + par_dict = { + **dict(zip(parameter_ids, edata.parameters)), + **dict(zip(fixed_parameter_ids, edata.fixedParameters)), + } + + p = jnp.array([par_dict[par_id] for par_id in jax_model.parameter_ids]) args = ( jnp.array([]), # p_preeq jnp.array(ts_preeq), # ts_preeq @@ -187,6 +214,10 @@ def check_fields_jax( 0 ] + amici_par_idx = np.array( + [jax_model.parameter_ids.index(par_id) for par_id in parameter_ids] + ) + for field in fields: for r_amici, r_jax in zip(rs_amici, [r_jax]): actual = r_jax[field] @@ -199,16 +230,16 @@ def check_fields_jax( axis=1, ) elif field == "sllh": - actual = actual[: len(edata.parameters)] + actual = actual[amici_par_idx] elif field == "sx": - actual = np.permute_dims( - actual[iys == 0, :, : len(edata.parameters)], (0, 2, 1) - ) + actual = actual[:, :, amici_par_idx] + actual = np.permute_dims(actual[iys == 0, :, :], (0, 2, 1)) elif field == "sy": + actual = actual[:, amici_par_idx] actual = np.permute_dims( np.stack( [ - actual[iys == iy, : len(edata.parameters)] + actual[iys == iy, :] for iy in sorted(np.unique(iys)) ], axis=1, @@ -216,9 +247,9 @@ def check_fields_jax( (0, 2, 1), ) elif field == "sx0": - actual = actual[:, : len(edata.parameters)].T + actual = actual[:, amici_par_idx].T elif field == "sres": - actual = actual[:, : len(edata.parameters)] + actual = actual[:, amici_par_idx] assert_allclose( actual=actual,