From 186805c8f3d891ea7fa621e1b4b7336cce39d3f2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabian=20Fr=C3=B6hlich?= Date: Sun, 17 Nov 2024 20:23:55 +0000 Subject: [PATCH] add api versioning and reenable jit compilation --- python/sdist/amici/de_export.py | 3 +++ python/sdist/amici/jax.template.py | 2 ++ python/sdist/amici/jax/model.py | 12 +++++++++++- 3 files changed, 16 insertions(+), 1 deletion(-) diff --git a/python/sdist/amici/de_export.py b/python/sdist/amici/de_export.py index 823f5f8ca1..1bace90510 100644 --- a/python/sdist/amici/de_export.py +++ b/python/sdist/amici/de_export.py @@ -278,6 +278,8 @@ def _prepare_model_folder(self) -> None: @log_execution_time("generating jax code", logger) def _generate_jax_code(self) -> None: + from amici.jax.model import JAXModel + eq_names = ( "xdot", "w", @@ -360,6 +362,7 @@ def jnp_array_str(array) -> str: }, **{ "MODEL_NAME": self.model_name, + "MODEL_API_VERSION": f"'{JAXModel.MODEL_API_VERSION}'", }, } os.makedirs( diff --git a/python/sdist/amici/jax.template.py b/python/sdist/amici/jax.template.py index 67a9decf07..05d82288d5 100644 --- a/python/sdist/amici/jax.template.py +++ b/python/sdist/amici/jax.template.py @@ -5,6 +5,8 @@ class JAXModel_TPL_MODEL_NAME(JAXModel): + api_version = TPL_MODEL_API_VERSION + def __init__(self): super().__init__() diff --git a/python/sdist/amici/jax/model.py b/python/sdist/amici/jax/model.py index 22f994229d..9335d1a0a7 100644 --- a/python/sdist/amici/jax/model.py +++ b/python/sdist/amici/jax/model.py @@ -18,6 +18,16 @@ class JAXModel(eqx.Module): classes inheriting from JAXModel. """ + MODEL_API_VERSION = "0.0.1" + api_version: str + + def __init__(self): + if self.api_version != self.MODEL_API_VERSION: + raise ValueError( + "JAXModel API version mismatch, please regenerate the model class." + ) + super().__init__() + @abstractmethod def _xdot( self, @@ -406,7 +416,7 @@ def _sigmays( in_axes=(0, 0, None, None, 0), )(ts, xs, p, tcl, iys) - # @eqx.filter_jit + @eqx.filter_jit def simulate_condition( self, p: jt.Float[jt.Array, "np"],