Skip to content

Commit

Permalink
Decouple JAX & C++ code generation (#2615)
Browse files Browse the repository at this point in the history
* refactor

* correct doc

* remove PK

* fix tests

* fix notebook

* fix parameter ids

* fix jax test

* fix notebook

* reviews

* fixup
  • Loading branch information
FFroehlich authored Dec 3, 2024
1 parent bd3bd91 commit 0d49041
Show file tree
Hide file tree
Showing 13 changed files with 832 additions and 417 deletions.
2 changes: 0 additions & 2 deletions python/sdist/amici/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,6 @@ def get_model(self) -> amici.Model:
"""Create a model instance."""
...

def get_jax_model(self) -> JAXModel: ...

AmiciModel = Union[amici.Model, amici.ModelPtr]
else:
ModelModule = ModuleType
Expand Down
30 changes: 0 additions & 30 deletions python/sdist/amici/__init__.template.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,10 @@
"""AMICI-generated module for model TPL_MODELNAME"""

import datetime
import os
import sys
from pathlib import Path
from typing import TYPE_CHECKING
import amici


if TYPE_CHECKING:
from amici.jax import JAXModel

# Ensure we are binary-compatible, see #556
if "TPL_AMICI_VERSION" != amici.__version__:
raise amici.AmiciVersionError(
Expand Down Expand Up @@ -38,28 +32,4 @@
# when the model package is imported via `import`
TPL_MODELNAME._model_module = sys.modules[__name__]


def get_jax_model() -> "JAXModel":
# If the model directory was meanwhile overwritten, this would load the
# new version, which would not match the previously imported extension.
# This is not allowed, as it would lead to inconsistencies.
jax_py_file = Path(__file__).parent / "jax.py"
jax_py_file = jax_py_file.resolve()
t_imported = TPL_MODELNAME._get_import_time() # noqa: protected-access
t_modified = os.path.getmtime(jax_py_file)
if t_imported < t_modified:
t_imp_str = datetime.datetime.fromtimestamp(t_imported).isoformat()
t_mod_str = datetime.datetime.fromtimestamp(t_modified).isoformat()
raise RuntimeError(
f"Refusing to import {jax_py_file} which was changed since "
f"TPL_MODELNAME was imported. This is to avoid inconsistencies "
"between the different model implementations.\n"
f"Imported at {t_imp_str}\nModified at {t_mod_str}.\n"
"Import the module with a different name or restart the "
"Python kernel."
)
jax = amici._module_from_path("jax", jax_py_file)
return jax.JAXModel_TPL_MODELNAME()


__version__ = "TPL_PACKAGE_VERSION"
163 changes: 20 additions & 143 deletions python/sdist/amici/de_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
TYPE_CHECKING,
Literal,
)
from itertools import chain

import sympy as sp

Expand Down Expand Up @@ -56,7 +55,6 @@
AmiciCxxCodePrinter,
get_switch_statement,
)
from .jaxcodeprinter import AmiciJaxCodePrinter
from .de_model import DEModel
from .de_model_components import *
from .import_utils import (
Expand Down Expand Up @@ -146,10 +144,7 @@ class DEExporter:
If the given model uses special functions, this set contains hints for
model building.
:ivar _code_printer_jax:
Code printer to generate JAX code
:ivar _code_printer_cpp:
:ivar _code_printer:
Code printer to generate C++ code
:ivar generate_sensitivity_code:
Expand Down Expand Up @@ -218,15 +213,14 @@ def __init__(
self.set_name(model_name)
self.set_paths(outdir)

self._code_printer_cpp = AmiciCxxCodePrinter()
self._code_printer_jax = AmiciJaxCodePrinter()
self._code_printer = AmiciCxxCodePrinter()
for fun in CUSTOM_FUNCTIONS:
self._code_printer_cpp.known_functions[fun["sympy"]] = fun["c++"]
self._code_printer.known_functions[fun["sympy"]] = fun["c++"]

# Signatures and properties of generated model functions (see
# include/amici/model.h for details)
self.model: DEModel = de_model
self._code_printer_cpp.known_functions.update(
self._code_printer.known_functions.update(
splines.spline_user_functions(
self.model._splines, self._get_index("p")
)
Expand All @@ -249,7 +243,6 @@ def generate_model_code(self) -> None:
sp.Pow, "_eval_derivative", _custom_pow_eval_derivative
):
self._prepare_model_folder()
self._generate_jax_code()
self._generate_c_code()
self._generate_m_code()

Expand Down Expand Up @@ -277,121 +270,6 @@ def _prepare_model_folder(self) -> None:
if os.path.isfile(file_path):
os.remove(file_path)

@log_execution_time("generating jax code", logger)
def _generate_jax_code(self) -> None:
try:
from amici.jax.model import JAXModel
except ImportError:
logger.warning(
"Could not import JAXModel. JAX code will not be generated."
)
return

eq_names = (
"xdot",
"w",
"x0",
"y",
"sigmay",
"Jy",
"x_solver",
"x_rdata",
"total_cl",
)
sym_names = ("x", "tcl", "w", "my", "y", "sigmay", "x_rdata")

indent = 8

def jnp_array_str(array) -> str:
elems = ", ".join(str(s) for s in array)

return f"jnp.array([{elems}])"

# replaces Heaviside variables with corresponding functions
subs_heaviside = dict(
zip(
self.model.sym("h"),
[sp.Heaviside(x) for x in self.model.eq("root")],
strict=True,
)
)
# replaces observables with a generic my variable
subs_observables = dict(
zip(
self.model.sym("my"),
[sp.Symbol("my")] * len(self.model.sym("my")),
strict=True,
)
)

tpl_data = {
# assign named variable using corresponding algebraic formula (function body)
**{
f"{eq_name.upper()}_EQ": "\n".join(
self._code_printer_jax._get_sym_lines(
(str(strip_pysb(s)) for s in self.model.sym(eq_name)),
self.model.eq(eq_name).subs(
{**subs_heaviside, **subs_observables}
),
indent,
)
)[indent:] # remove indent for first line
for eq_name in eq_names
},
# create jax array from concatenation of named variables
**{
f"{eq_name.upper()}_RET": jnp_array_str(
strip_pysb(s) for s in self.model.sym(eq_name)
)
if self.model.sym(eq_name)
else "jnp.array([])"
for eq_name in eq_names
},
# assign named variables from a jax array
**{
f"{sym_name.upper()}_SYMS": "".join(
str(strip_pysb(s)) + ", " for s in self.model.sym(sym_name)
)
if self.model.sym(sym_name)
else "_"
for sym_name in sym_names
},
# tuple of variable names (ids as they are unique)
**{
f"{sym_name.upper()}_IDS": "".join(
f'"{strip_pysb(s)}", ' for s in self.model.sym(sym_name)
)
if self.model.sym(sym_name)
else "tuple()"
for sym_name in ("p", "k", "y", "x")
},
**{
# in jax model we do not need to distinguish between p (parameters) and
# k (fixed parameters) so we use a single variable combining both
"PK_SYMS": "".join(
str(strip_pysb(s)) + ", "
for s in chain(self.model.sym("p"), self.model.sym("k"))
),
"PK_IDS": "".join(
f'"{strip_pysb(s)}", '
for s in chain(self.model.sym("p"), self.model.sym("k"))
),
"MODEL_NAME": self.model_name,
# keep track of the API version that the model was generated with so we
# can flag conflicts in the future
"MODEL_API_VERSION": f"'{JAXModel.MODEL_API_VERSION}'",
},
}
os.makedirs(
os.path.join(self.model_path, self.model_name), exist_ok=True
)

apply_template(
os.path.join(amiciModulePath, "jax.template.py"),
os.path.join(self.model_path, self.model_name, "jax.py"),
tpl_data,
)

def _generate_c_code(self) -> None:
"""
Create C++ code files for the model based on
Expand Down Expand Up @@ -795,7 +673,7 @@ def _get_function_body(
lines = []

if len(equations) == 0 or (
isinstance(equations, (sp.Matrix, sp.ImmutableDenseMatrix))
isinstance(equations, sp.Matrix | sp.ImmutableDenseMatrix)
and min(equations.shape) == 0
):
# dJydy is a list
Expand Down Expand Up @@ -852,7 +730,7 @@ def _get_function_body(
f"reinitialization_state_idxs.cend(), {index}) != "
"reinitialization_state_idxs.cend())",
f" {function}[{index}] = "
f"{self._code_printer_cpp.doprint(formula)};",
f"{self._code_printer.doprint(formula)};",
]
)
cases[ipar] = expressions
Expand All @@ -867,12 +745,12 @@ def _get_function_body(
f"reinitialization_state_idxs.cend(), {index}) != "
"reinitialization_state_idxs.cend())\n "
f"{function}[{index}] = "
f"{self._code_printer_cpp.doprint(formula)};"
f"{self._code_printer.doprint(formula)};"
)

elif function in event_functions:
cases = {
ie: self._code_printer_cpp._get_sym_lines_array(
ie: self._code_printer._get_sym_lines_array(
equations[ie], function, 0
)
for ie in range(self.model.num_events())
Expand All @@ -885,7 +763,7 @@ def _get_function_body(
for ie, inner_equations in enumerate(equations):
inner_lines = []
inner_cases = {
ipar: self._code_printer_cpp._get_sym_lines_array(
ipar: self._code_printer._get_sym_lines_array(
inner_equations[:, ipar], function, 0
)
for ipar in range(self.model.num_par())
Expand All @@ -900,7 +778,7 @@ def _get_function_body(
and equations.shape[1] == self.model.num_par()
):
cases = {
ipar: self._code_printer_cpp._get_sym_lines_array(
ipar: self._code_printer._get_sym_lines_array(
equations[:, ipar], function, 0
)
for ipar in range(self.model.num_par())
Expand All @@ -910,15 +788,15 @@ def _get_function_body(
elif function in multiobs_functions:
if function == "dJydy":
cases = {
iobs: self._code_printer_cpp._get_sym_lines_array(
iobs: self._code_printer._get_sym_lines_array(
equations[iobs], function, 0
)
for iobs in range(self.model.num_obs())
if not smart_is_zero_matrix(equations[iobs])
}
else:
cases = {
iobs: self._code_printer_cpp._get_sym_lines_array(
iobs: self._code_printer._get_sym_lines_array(
equations[:, iobs], function, 0
)
for iobs in range(equations.shape[1])
Expand Down Expand Up @@ -948,7 +826,7 @@ def _get_function_body(
tmp_equations = sp.Matrix(
[equations[i] for i in static_idxs]
)
tmp_lines = self._code_printer_cpp._get_sym_lines_symbols(
tmp_lines = self._code_printer._get_sym_lines_symbols(
tmp_symbols,
tmp_equations,
function,
Expand All @@ -974,7 +852,7 @@ def _get_function_body(
[equations[i] for i in dynamic_idxs]
)

tmp_lines = self._code_printer_cpp._get_sym_lines_symbols(
tmp_lines = self._code_printer._get_sym_lines_symbols(
tmp_symbols,
tmp_equations,
function,
Expand All @@ -986,12 +864,12 @@ def _get_function_body(
lines.extend(tmp_lines)

else:
lines += self._code_printer_cpp._get_sym_lines_symbols(
lines += self._code_printer._get_sym_lines_symbols(
symbols, equations, function, 4
)

else:
lines += self._code_printer_cpp._get_sym_lines_array(
lines += self._code_printer._get_sym_lines_array(
equations, function, 4
)

Expand Down Expand Up @@ -1136,8 +1014,7 @@ def _write_model_header_cpp(self) -> None:
)
),
"NDXDOTDX_EXPLICIT": len(self.model.sparsesym("dxdotdx_explicit")),
"NDJYDY": "std::vector<int>{%s}"
% ",".join(str(len(x)) for x in self.model.sparsesym("dJydy")),
"NDJYDY": f"std::vector<int>{{{','.join(str(len(x)) for x in self.model.sparsesym('dJydy'))}}}",
"NDXRDATADXSOLVER": len(self.model.sparsesym("dx_rdatadx_solver")),
"NDXRDATADTCL": len(self.model.sparsesym("dx_rdatadtcl")),
"NDTOTALCLDXRDATA": len(self.model.sparsesym("dtotal_cldx_rdata")),
Expand All @@ -1147,10 +1024,10 @@ def _write_model_header_cpp(self) -> None:
"NK": self.model.num_const(),
"O2MODE": "amici::SecondOrderMode::none",
# using code printer ensures proper handling of nan/inf
"PARAMETERS": self._code_printer_cpp.doprint(self.model.val("p"))[
"PARAMETERS": self._code_printer.doprint(self.model.val("p"))[
1:-1
],
"FIXED_PARAMETERS": self._code_printer_cpp.doprint(
"FIXED_PARAMETERS": self._code_printer.doprint(
self.model.val("k")
)[1:-1],
"PARAMETER_NAMES_INITIALIZER_LIST": self._get_symbol_name_initializer_list(
Expand Down Expand Up @@ -1344,7 +1221,7 @@ def _get_symbol_id_initializer_list(self, name: str) -> str:
Template initializer list of ids
"""
return "\n".join(
f'"{self._code_printer_cpp.doprint(symbol)}", // {name}[{idx}]'
f'"{self._code_printer.doprint(symbol)}", // {name}[{idx}]'
for idx, symbol in enumerate(self.model.sym(name))
)

Expand Down
Loading

0 comments on commit 0d49041

Please sign in to comment.