Skip to content

Commit

Permalink
fixup merge
Browse files Browse the repository at this point in the history
  • Loading branch information
FFroehlich committed Dec 3, 2024
1 parent eea558b commit b8632f1
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 2 deletions.
3 changes: 1 addition & 2 deletions python/sdist/amici/jax/jax.template.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def __init__(self):
super().__init__()

def _xdot(self, t, x, args):
pk, tcl = args
p, tcl = args

TPL_X_SYMS = x
TPL_P_SYMS = p
Expand All @@ -31,7 +31,6 @@ def _xdot(self, t, x, args):

return TPL_XDOT_RET


def _w(self, t, x, p, tcl):
TPL_X_SYMS = x
TPL_P_SYMS = p
Expand Down
22 changes: 22 additions & 0 deletions python/sdist/amici/jax/ode_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from amici._codegen.template import apply_template
from amici.jax.jaxcodeprinter import AmiciJaxCodePrinter
from amici.jax.model import JAXModel
from amici.jax.nn import generate_equinox
from amici.de_model import DEModel
from amici.de_export import is_valid_identifier
from amici.import_utils import (
Expand Down Expand Up @@ -129,6 +130,7 @@ def __init__(
outdir: Path | str | None = None,
verbose: bool | int | None = False,
model_name: str | None = "model",
hybridisation: dict[str, str] = {},
):
"""
Generate AMICI jax files for the ODE provided to the constructor.
Expand Down Expand Up @@ -157,6 +159,8 @@ def __init__(

self.model: DEModel = ode_model

self.hybridisation = hybridisation

Check warning on line 162 in python/sdist/amici/jax/ode_export.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/ode_export.py#L162

Added line #L162 was not covered by tests

self._code_printer = AmiciJaxCodePrinter()

@log_execution_time("generating jax code", logger)
Expand All @@ -169,6 +173,7 @@ def generate_model_code(self) -> None:
):
self._prepare_model_folder()
self._generate_jax_code()
self._generate_nn_code()

Check warning on line 176 in python/sdist/amici/jax/ode_export.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/ode_export.py#L176

Added line #L176 was not covered by tests

def _prepare_model_folder(self) -> None:
"""
Expand Down Expand Up @@ -233,6 +238,14 @@ def _generate_jax_code(self) -> None:
# can flag conflicts in the future
"MODEL_API_VERSION": f"'{JAXModel.MODEL_API_VERSION}'",
},
"NET_IMPORTS": "\n".join(
f"{net} = _module_from_path('{net}', Path(__file__).parent / '{net}.py')"
for net in self.hybridisation.keys()
),
"NETS": ",\n".join(
f'"{net}": {net}.net(jr.PRNGKey(0))'
for net in self.hybridisation.keys()
),
}
outdir = self.model_path / (self.model_name + "_jax")
outdir.mkdir(parents=True, exist_ok=True)
Expand All @@ -243,6 +256,15 @@ def _generate_jax_code(self) -> None:
tpl_data,
)

def _generate_nn_code(self) -> None:
for net_name, net in self.hybridisation.items():
generate_equinox(

Check warning on line 261 in python/sdist/amici/jax/ode_export.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/ode_export.py#L259-L261

Added lines #L259 - L261 were not covered by tests
net["model"],
os.path.join(
self.model_path, self.model_name + "_jax", f"{net_name}.py"
),
)

def set_paths(self, output_dir: str | Path | None = None) -> None:
"""
Set output paths for the model and create if necessary
Expand Down
2 changes: 2 additions & 0 deletions python/sdist/amici/sbml_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,7 @@ def sbml2jax(
simplify: Callable | None = _default_simplify,
cache_simplify: bool = False,
log_as_log10: bool = True,
hybridisation: dict = None,
) -> None:
"""
Generate and compile AMICI jax files for the model provided to the
Expand Down Expand Up @@ -549,6 +550,7 @@ def sbml2jax(
model_name=model_name,
outdir=output_dir,
verbose=verbose,
hybridisation=hybridisation,
)
exporter.generate_model_code()

Expand Down

0 comments on commit b8632f1

Please sign in to comment.