Skip to content

Commit

Permalink
add api versioning and reenable jit compilation
Browse files Browse the repository at this point in the history
  • Loading branch information
FFroehlich committed Nov 17, 2024
1 parent 0a9fcdf commit 186805c
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 1 deletion.
3 changes: 3 additions & 0 deletions python/sdist/amici/de_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,8 @@ def _prepare_model_folder(self) -> None:

@log_execution_time("generating jax code", logger)
def _generate_jax_code(self) -> None:
from amici.jax.model import JAXModel

eq_names = (
"xdot",
"w",
Expand Down Expand Up @@ -360,6 +362,7 @@ def jnp_array_str(array) -> str:
},
**{
"MODEL_NAME": self.model_name,
"MODEL_API_VERSION": f"'{JAXModel.MODEL_API_VERSION}'",
},
}
os.makedirs(
Expand Down
2 changes: 2 additions & 0 deletions python/sdist/amici/jax.template.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@


class JAXModel_TPL_MODEL_NAME(JAXModel):
api_version = TPL_MODEL_API_VERSION

def __init__(self):
super().__init__()

Expand Down
12 changes: 11 additions & 1 deletion python/sdist/amici/jax/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,16 @@ class JAXModel(eqx.Module):
classes inheriting from JAXModel.
"""

MODEL_API_VERSION = "0.0.1"
api_version: str

def __init__(self):
if self.api_version != self.MODEL_API_VERSION:
raise ValueError(
"JAXModel API version mismatch, please regenerate the model class."
)
super().__init__()

@abstractmethod
def _xdot(
self,
Expand Down Expand Up @@ -406,7 +416,7 @@ def _sigmays(
in_axes=(0, 0, None, None, 0),
)(ts, xs, p, tcl, iys)

# @eqx.filter_jit
@eqx.filter_jit
def simulate_condition(
self,
p: jt.Float[jt.Array, "np"],
Expand Down

0 comments on commit 186805c

Please sign in to comment.