Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Jax export #1861

Merged
merged 94 commits into from
Nov 19, 2024
Merged
Show file tree
Hide file tree
Changes from 79 commits
Commits
Show all changes
94 commits
Select commit Hold shift + click to select a range
328d462
basic prototype
FFroehlich Aug 25, 2022
ffa5afb
Merge branch 'develop' into jax_export
FFroehlich Aug 26, 2022
d4f8552
add dimerization example, add second order code, refactor jit
FFroehlich Aug 26, 2022
d37a850
remove equinox dependency, list dependencies
FFroehlich Aug 26, 2022
ff37c7e
make jax optional
FFroehlich Aug 26, 2022
c3a77f7
Merge branch 'develop' into jax_export
FFroehlich Aug 26, 2022
7cd8553
support conservation laws
FFroehlich Aug 26, 2022
5177ad7
fixup
FFroehlich Aug 26, 2022
5612cfc
fix jit nesting
FFroehlich Aug 26, 2022
2dd0377
use vmap for vectorization
FFroehlich Aug 26, 2022
e9bd14f
fixups
FFroehlich Aug 26, 2022
bbb5246
add multithreaded simulation runner
FFroehlich Aug 26, 2022
9bd1004
fix my
FFroehlich Aug 26, 2022
1b06c24
Merge branch 'develop' into jax_export
FFroehlich Sep 9, 2022
599aa71
fixes
FFroehlich Sep 13, 2022
51812d6
merge
FFroehlich Apr 10, 2024
3fbd17a
fixup merge
FFroehlich Apr 10, 2024
5974d47
fix install
FFroehlich Apr 10, 2024
37cdc81
actually generate code
FFroehlich Apr 10, 2024
9e6a0ff
fix
FFroehlich Apr 10, 2024
22b2b38
fix
FFroehlich Apr 10, 2024
48a2e49
add better default coefficients, fix jax
FFroehlich Apr 10, 2024
481216d
ignore fujita in jax
FFroehlich Apr 10, 2024
85b8173
ignore smith
FFroehlich Apr 10, 2024
b213adb
optimize & fix bachmann
FFroehlich Apr 11, 2024
a1f37b7
fix import/wokflow
FFroehlich Apr 11, 2024
e09bb2f
Update __init__.template.py
FFroehlich Apr 12, 2024
d8d1900
fix jax imports
FFroehlich Apr 12, 2024
c24fe6b
Update setup.cfg
FFroehlich Apr 12, 2024
1ec591c
add preequilibration support
FFroehlich Apr 12, 2024
aebe07c
fix jax tests
FFroehlich Apr 13, 2024
4125c51
add filterwarning
FFroehlich Apr 14, 2024
8143cc2
fix parameter transformation
FFroehlich Apr 14, 2024
781bb3b
Merge branch 'develop' into jax_export
FFroehlich Oct 19, 2024
81e2aeb
reenable ruff format
FFroehlich Oct 19, 2024
c01f707
post merge cleanup
FFroehlich Oct 19, 2024
a5d356a
"fix" splines
FFroehlich Oct 19, 2024
9a021cf
Update .pre-commit-config.yaml
FFroehlich Oct 19, 2024
a02d215
Merge branch 'develop' into jax_export
FFroehlich Oct 19, 2024
50193d8
force optimistix 0.0.9
FFroehlich Oct 21, 2024
d6c5bcd
Merge branch 'jax_export' of https://github.com/AMICI-dev/AMICI into …
FFroehlich Oct 21, 2024
7faae32
add support for heavyside functions
FFroehlich Oct 21, 2024
907acb7
cleanup & actually run tests
FFroehlich Oct 21, 2024
82a01ba
simply tests + add support for non-dynamic simulation in jax
FFroehlich Oct 22, 2024
7c3aef9
Merge branch 'develop' into jax_export
FFroehlich Oct 23, 2024
c548c93
fix for NONCONST_CLS
FFroehlich Oct 24, 2024
7c27a21
fix petab path
FFroehlich Oct 24, 2024
b84dbdb
Merge branch 'develop' into jax_export
FFroehlich Oct 24, 2024
37b9329
Merge branch 'develop' into jax_export
FFroehlich Oct 24, 2024
956b0a6
fixup merge
FFroehlich Oct 24, 2024
2f3834d
support postequilibration
FFroehlich Oct 25, 2024
5366632
fixup
FFroehlich Oct 25, 2024
5a86f4c
fix
FFroehlich Oct 25, 2024
480b75a
fix gradients
FFroehlich Oct 25, 2024
8b9c10a
fix hessian
FFroehlich Oct 25, 2024
7dc81ac
Update test_petab_benchmark.py
FFroehlich Oct 25, 2024
866c811
Merge branch 'develop' into jax_export
FFroehlich Oct 27, 2024
02a1272
skip smith in jax
FFroehlich Oct 27, 2024
51bd18c
exclude more models
FFroehlich Oct 27, 2024
c7c5d4b
refactor: remove use of edatas
FFroehlich Nov 9, 2024
a514deb
update template
FFroehlich Nov 9, 2024
498681a
Update .pre-commit-config.yaml
FFroehlich Nov 9, 2024
4a5e7d2
Merge branch 'develop' into jax_export
FFroehlich Nov 11, 2024
f745be0
fix python jax tests
FFroehlich Nov 12, 2024
a64f89b
simplify petab interface
FFroehlich Nov 12, 2024
7292451
add parameter values to model class
FFroehlich Nov 12, 2024
da02106
refactor parameter mapping
FFroehlich Nov 12, 2024
a46e65d
refactor & simplify
FFroehlich Nov 12, 2024
404d82e
refsctor
FFroehlich Nov 16, 2024
e399f4c
update template
FFroehlich Nov 16, 2024
eaae778
Update .pre-commit-config.yaml
FFroehlich Nov 16, 2024
d79cfc1
refactor fix test
FFroehlich Nov 16, 2024
94aa679
Update petab.py
FFroehlich Nov 16, 2024
b129c86
fixups
FFroehlich Nov 17, 2024
9b6a62b
fixup
FFroehlich Nov 17, 2024
74cd498
add documentation and typing
FFroehlich Nov 17, 2024
d94714b
add runtime typechecks to jax tests
FFroehlich Nov 17, 2024
0a9fcdf
add coverage from benchmark tests
FFroehlich Nov 17, 2024
186805c
add api versioning and reenable jit compilation
FFroehlich Nov 17, 2024
250f9dd
review comments
FFroehlich Nov 18, 2024
dc4992e
use temporary directories
FFroehlich Nov 18, 2024
d547509
fix doc
FFroehlich Nov 18, 2024
82bfe31
Update test_jax.py
FFroehlich Nov 18, 2024
a010803
don't generate code if jax/diffrax not available
FFroehlich Nov 18, 2024
d9ae05e
Merge branch 'develop' into jax_export
FFroehlich Nov 18, 2024
f7c2c10
add example
FFroehlich Nov 19, 2024
5dc8735
fix doc
FFroehlich Nov 19, 2024
784ab2c
fix notebook symlink
FFroehlich Nov 19, 2024
d528168
update notebook
FFroehlich Nov 19, 2024
24d8c09
Update ExampleJaxPEtab.ipynb
FFroehlich Nov 19, 2024
5393e6c
Update ExampleJaxPEtab.ipynb
FFroehlich Nov 19, 2024
a22f099
fix compilation issue
FFroehlich Nov 19, 2024
a585414
Merge branch 'develop' into jax_export
FFroehlich Nov 19, 2024
c242b15
fix
FFroehlich Nov 19, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 17 additions & 4 deletions .github/workflows/test_benchmark_collection_models.yml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ jobs:
run: |
pip3 install --user petab[vis] && \
AMICI_PARALLEL_COMPILE="" pip3 install -v --user \
$(ls -t python/sdist/dist/amici-*.tar.gz | head -1)[petab,test,vis]
$(ls -t python/sdist/dist/amici-*.tar.gz | head -1)[petab,test,vis,jax]

- name: Install test dependencies
run: |
Expand All @@ -60,14 +60,27 @@ jobs:

- name: Download benchmark collection
run: |
git clone --depth 1 https://github.com/benchmarking-initiative/Benchmark-Models-PEtab.git \
&& python3 -m pip install -e Benchmark-Models-PEtab/src/python
pip install git+https://github.com/Benchmarking-Initiative/Benchmark-Models-PEtab.git@master#subdirectory=src/python

- name: Run tests
env:
AMICI_PARALLEL_COMPILE: ""
run: |
cd tests/benchmark-models && pytest --durations=10
cd tests/benchmark-models && pytest \
--durations=10
--cov=amici \
--cov-report=xml:"coverage_py.xml" \
--cov-append \

- name: Codecov Python
if: github.event_name == 'pull_request' || github.repository_owner == 'AMICI-dev'
uses: codecov/codecov-action@v4
with:
token: ${{ secrets.CODECOV_TOKEN }}
file: coverage_py.xml
flags: python
fail_ci_if_error: true
verbose: true

# collect & upload results
- name: Aggregate results
Expand Down
5 changes: 0 additions & 5 deletions .github/workflows/test_python_cplusplus.yml
Original file line number Diff line number Diff line change
Expand Up @@ -231,11 +231,6 @@ jobs:
- name: Install python package
run: scripts/installAmiciSource.sh

- name: Install notebook dependencies
run: |
source venv/bin/activate \
&& pip install jax[cpu]

- name: example notebooks
run: scripts/runNotebook.sh python/examples/example_*/

Expand Down
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,8 @@ tests/test/*
*/tests/explicit_amici/*
*/tests/fixed_initial_amici/*
*/tests/localfunc_amici/*
*/tests/conversion/*
*/tests/dimerization/*
tests/cpp/writeResults.h5
tests/cpp/writeResults.h5.bak
tests/sbml-test-suite/*
Expand Down
2 changes: 2 additions & 0 deletions documentation/rtd_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ sphinx<8
mock>=5.0.2
setuptools>=67.7.2
pysb>=1.11.0
jax>=0.4.26
diffrax>=0.5.0
matplotlib==3.7.1
nbsphinx==0.9.1
nbformat==5.8.0
Expand Down
10 changes: 10 additions & 0 deletions python/sdist/amici/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,11 @@ def _imported_from_setup() -> bool:
assignmentRules2observables,
)

try:
from .jax import JAXModel
except (ImportError, ModuleNotFoundError):
JAXModel = object

@runtime_checkable
class ModelModule(Protocol): # noqa: F811
"""Type of AMICI-generated model modules.
Expand All @@ -135,6 +140,11 @@ def get_model(self) -> amici.Model:
"""Create a model instance."""
...

def get_jax_model(self) -> JAXModel:
...

AmiciModel = Union[amici.Model, amici.ModelPtr]


class add_path:
"""Context manager for temporarily changing PYTHONPATH"""
Expand Down
12 changes: 11 additions & 1 deletion python/sdist/amici/__init__.template.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
"""AMICI-generated module for model TPL_MODELNAME"""

from pathlib import Path

from typing import TYPE_CHECKING
import amici

if TYPE_CHECKING:
from amici.jax import JAXModel

# Ensure we are binary-compatible, see #556
if "TPL_AMICI_VERSION" != amici.__version__:
raise amici.AmiciVersionError(
Expand All @@ -18,4 +21,11 @@
from .TPL_MODELNAME import * # noqa: F403, F401
from .TPL_MODELNAME import getModel as get_model # noqa: F401


def get_jax_model() -> "JAXModel":
from .jax import JAXModel_TPL_MODELNAME

return JAXModel_TPL_MODELNAME()


__version__ = "TPL_PACKAGE_VERSION"
142 changes: 124 additions & 18 deletions python/sdist/amici/de_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
TYPE_CHECKING,
Literal,
)

import sympy as sp

from . import (
Expand Down Expand Up @@ -54,6 +55,7 @@
AmiciCxxCodePrinter,
get_switch_statement,
)
from .jaxcodeprinter import AmiciJaxCodePrinter
from .de_model import DEModel
from .de_model_components import *
from .import_utils import (
Expand Down Expand Up @@ -143,7 +145,10 @@ class DEExporter:
If the given model uses special functions, this set contains hints for
model building.

:ivar _code_printer:
:ivar _code_printer_jax:
Code printer to generate JAX code

:ivar _code_printer_cpp:
Code printer to generate C++ code

:ivar generate_sensitivity_code:
Expand Down Expand Up @@ -212,14 +217,15 @@ def __init__(
self.set_name(model_name)
self.set_paths(outdir)

self._code_printer = AmiciCxxCodePrinter()
self._code_printer_cpp = AmiciCxxCodePrinter()
self._code_printer_jax = AmiciJaxCodePrinter()
for fun in CUSTOM_FUNCTIONS:
self._code_printer.known_functions[fun["sympy"]] = fun["c++"]
self._code_printer_cpp.known_functions[fun["sympy"]] = fun["c++"]

# Signatures and properties of generated model functions (see
# include/amici/model.h for details)
self.model: DEModel = de_model
self._code_printer.known_functions.update(
self._code_printer_cpp.known_functions.update(
splines.spline_user_functions(
self.model._splines, self._get_index("p")
)
Expand All @@ -242,6 +248,7 @@ def generate_model_code(self) -> None:
sp.Pow, "_eval_derivative", _custom_pow_eval_derivative
):
self._prepare_model_folder()
self._generate_jax_code()
self._generate_c_code()
self._generate_m_code()

Expand Down Expand Up @@ -269,6 +276,105 @@ def _prepare_model_folder(self) -> None:
if os.path.isfile(file_path):
os.remove(file_path)

@log_execution_time("generating jax code", logger)
def _generate_jax_code(self) -> None:
from amici.jax.model import JAXModel

eq_names = (
"xdot",
"w",
"x0",
"y",
"sigmay",
"Jy",
"x_solver",
"x_rdata",
"total_cl",
)
sym_names = ("x", "tcl", "w", "my", "y", "sigmay", "x_rdata")

indent = 8

def jnp_array_str(array) -> str:
elems = ", ".join(str(s) for s in array)

return f"jnp.array([{elems}])"

tpl_data = {
**{
f"{eq_name.upper()}_EQ": "\n".join(
self._code_printer_jax._get_sym_lines(
(str(strip_pysb(s)) for s in self.model.sym(eq_name)),
self.model.eq(eq_name).subs(
dict(
zip(
list(self.model.sym("h"))
+ list(self.model.sym("my")),
[
sp.Heaviside(x)
for x in self.model.eq("root")
]
+ [sp.Symbol("my")]
* len(self.model.sym("my")),
)
)
),
indent,
)
)[indent:]
for eq_name in eq_names
},
FFroehlich marked this conversation as resolved.
Show resolved Hide resolved
**{
FFroehlich marked this conversation as resolved.
Show resolved Hide resolved
f"{eq_name.upper()}_RET": jnp_array_str(
strip_pysb(s) for s in self.model.sym(eq_name)
)
if self.model.sym(eq_name)
else "jnp.array([])"
for eq_name in eq_names
},
**{
f"{sym_name.upper()}_SYMS": "".join(
str(strip_pysb(s)) + ", " for s in self.model.sym(sym_name)
)
if self.model.sym(sym_name)
else "_"
for sym_name in sym_names
},
**{
f"{sym_name.upper()}_IDS": "".join(
f'"{strip_pysb(s)}", ' for s in self.model.sym(sym_name)
)
if self.model.sym(sym_name)
else "tuple()"
for sym_name in ("p", "k", "y", "x")
},
**{
"PK_SYMS": "".join(
FFroehlich marked this conversation as resolved.
Show resolved Hide resolved
str(strip_pysb(s)) + ", "
for s in list(self.model.sym("p"))
+ list(self.model.sym("k"))
FFroehlich marked this conversation as resolved.
Show resolved Hide resolved
),
"PK_IDS": "".join(
f'"{strip_pysb(s)}", '
for s in list(self.model.sym("p"))
+ list(self.model.sym("k"))
),
},
**{
"MODEL_NAME": self.model_name,
"MODEL_API_VERSION": f"'{JAXModel.MODEL_API_VERSION}'",
},
}
os.makedirs(
os.path.join(self.model_path, self.model_name), exist_ok=True
)

apply_template(
os.path.join(amiciModulePath, "jax.template.py"),
os.path.join(self.model_path, self.model_name, "jax.py"),
tpl_data,
)

def _generate_c_code(self) -> None:
"""
Create C++ code files for the model based on
Expand Down Expand Up @@ -729,7 +835,7 @@ def _get_function_body(
f"reinitialization_state_idxs.cend(), {index}) != "
"reinitialization_state_idxs.cend())",
f" {function}[{index}] = "
f"{self._code_printer.doprint(formula)};",
f"{self._code_printer_cpp.doprint(formula)};",
]
)
cases[ipar] = expressions
Expand All @@ -744,12 +850,12 @@ def _get_function_body(
f"reinitialization_state_idxs.cend(), {index}) != "
"reinitialization_state_idxs.cend())\n "
f"{function}[{index}] = "
f"{self._code_printer.doprint(formula)};"
f"{self._code_printer_cpp.doprint(formula)};"
)

elif function in event_functions:
cases = {
ie: self._code_printer._get_sym_lines_array(
ie: self._code_printer_cpp._get_sym_lines_array(
equations[ie], function, 0
)
for ie in range(self.model.num_events())
Expand All @@ -762,7 +868,7 @@ def _get_function_body(
for ie, inner_equations in enumerate(equations):
inner_lines = []
inner_cases = {
ipar: self._code_printer._get_sym_lines_array(
ipar: self._code_printer_cpp._get_sym_lines_array(
inner_equations[:, ipar], function, 0
)
for ipar in range(self.model.num_par())
Expand All @@ -777,7 +883,7 @@ def _get_function_body(
and equations.shape[1] == self.model.num_par()
):
cases = {
ipar: self._code_printer._get_sym_lines_array(
ipar: self._code_printer_cpp._get_sym_lines_array(
equations[:, ipar], function, 0
)
for ipar in range(self.model.num_par())
Expand All @@ -787,15 +893,15 @@ def _get_function_body(
elif function in multiobs_functions:
if function == "dJydy":
cases = {
iobs: self._code_printer._get_sym_lines_array(
iobs: self._code_printer_cpp._get_sym_lines_array(
equations[iobs], function, 0
)
for iobs in range(self.model.num_obs())
if not smart_is_zero_matrix(equations[iobs])
}
else:
cases = {
iobs: self._code_printer._get_sym_lines_array(
iobs: self._code_printer_cpp._get_sym_lines_array(
equations[:, iobs], function, 0
)
for iobs in range(equations.shape[1])
Expand Down Expand Up @@ -825,7 +931,7 @@ def _get_function_body(
tmp_equations = sp.Matrix(
[equations[i] for i in static_idxs]
)
tmp_lines = self._code_printer._get_sym_lines_symbols(
tmp_lines = self._code_printer_cpp._get_sym_lines_symbols(
tmp_symbols,
tmp_equations,
function,
Expand All @@ -851,7 +957,7 @@ def _get_function_body(
[equations[i] for i in dynamic_idxs]
)

tmp_lines = self._code_printer._get_sym_lines_symbols(
tmp_lines = self._code_printer_cpp._get_sym_lines_symbols(
tmp_symbols,
tmp_equations,
function,
Expand All @@ -863,12 +969,12 @@ def _get_function_body(
lines.extend(tmp_lines)

else:
lines += self._code_printer._get_sym_lines_symbols(
lines += self._code_printer_cpp._get_sym_lines_symbols(
symbols, equations, function, 4
)

else:
lines += self._code_printer._get_sym_lines_array(
lines += self._code_printer_cpp._get_sym_lines_array(
equations, function, 4
)

Expand Down Expand Up @@ -1024,10 +1130,10 @@ def _write_model_header_cpp(self) -> None:
"NK": self.model.num_const(),
"O2MODE": "amici::SecondOrderMode::none",
# using code printer ensures proper handling of nan/inf
"PARAMETERS": self._code_printer.doprint(self.model.val("p"))[
"PARAMETERS": self._code_printer_cpp.doprint(self.model.val("p"))[
1:-1
],
"FIXED_PARAMETERS": self._code_printer.doprint(
"FIXED_PARAMETERS": self._code_printer_cpp.doprint(
self.model.val("k")
)[1:-1],
"PARAMETER_NAMES_INITIALIZER_LIST": self._get_symbol_name_initializer_list(
Expand Down Expand Up @@ -1221,7 +1327,7 @@ def _get_symbol_id_initializer_list(self, name: str) -> str:
Template initializer list of ids
"""
return "\n".join(
f'"{self._code_printer.doprint(symbol)}", // {name}[{idx}]'
f'"{self._code_printer_cpp.doprint(symbol)}", // {name}[{idx}]'
for idx, symbol in enumerate(self.model.sym(name))
)

Expand Down
Loading
Loading