diff --git a/python/sdist/amici/jax.template.py b/python/sdist/amici/jax.template.py index 367ba9e500..9b566281ca 100644 --- a/python/sdist/amici/jax.template.py +++ b/python/sdist/amici/jax.template.py @@ -1,5 +1,7 @@ +# ruff: noqa: F401, F821, F841 import jax.numpy as jnp from interpax import interp1d +from pathlib import Path from amici.jax.model import JAXModel @@ -8,10 +10,10 @@ class JAXModel_TPL_MODEL_NAME(JAXModel): api_version = TPL_MODEL_API_VERSION def __init__(self): + self.jax_py_file = Path(__file__).resolve() super().__init__() def _xdot(self, t, x, args): - pk, tcl = args TPL_X_SYMS = x @@ -24,7 +26,6 @@ def _xdot(self, t, x, args): return TPL_XDOT_RET def _w(self, t, x, pk, tcl): - TPL_X_SYMS = x TPL_PK_SYMS = pk TPL_TCL_SYMS = tcl @@ -34,7 +35,6 @@ def _w(self, t, x, pk, tcl): return TPL_W_RET def _x0(self, pk): - TPL_PK_SYMS = pk TPL_X0_EQ @@ -42,7 +42,6 @@ def _x0(self, pk): return TPL_X0_RET def _x_solver(self, x): - TPL_X_RDATA_SYMS = x TPL_X_SOLVER_EQ @@ -50,7 +49,6 @@ def _x_solver(self, x): return TPL_X_SOLVER_RET def _x_rdata(self, x, tcl): - TPL_X_SYMS = x TPL_TCL_SYMS = tcl @@ -59,7 +57,6 @@ def _x_rdata(self, x, tcl): return TPL_X_RDATA_RET def _tcl(self, x, pk): - TPL_X_RDATA_SYMS = x TPL_PK_SYMS = pk @@ -68,7 +65,6 @@ def _tcl(self, x, pk): return TPL_TOTAL_CL_RET def _y(self, t, x, pk, tcl): - TPL_X_SYMS = x TPL_PK_SYMS = pk TPL_W_SYMS = self._w(t, x, pk, tcl) @@ -86,7 +82,6 @@ def _sigmay(self, y, pk): return TPL_SIGMAY_RET - def _nllh(self, t, x, pk, tcl, my, iy): y = self._y(t, x, pk, tcl) TPL_Y_SYMS = y @@ -107,3 +102,6 @@ def state_ids(self): @property def parameter_ids(self): return TPL_PK_IDS + + +Model = JAXModel_TPL_MODEL_NAME diff --git a/python/sdist/amici/jax/model.py b/python/sdist/amici/jax/model.py index a7b274027a..e037c44a2f 100644 --- a/python/sdist/amici/jax/model.py +++ b/python/sdist/amici/jax/model.py @@ -3,6 +3,7 @@ # ruff: noqa: F821 F722 from abc import abstractmethod +from pathlib import Path import diffrax import equinox as eqx @@ -18,8 +19,9 @@ class JAXModel(eqx.Module): classes inheriting from JAXModel. """ - MODEL_API_VERSION = "0.0.1" + MODEL_API_VERSION = "0.0.2" api_version: str + jax_py_file: Path def __init__(self): if self.api_version != self.MODEL_API_VERSION: diff --git a/python/sdist/amici/jax/petab.py b/python/sdist/amici/jax/petab.py index 6ddfb7c074..2c823259fe 100644 --- a/python/sdist/amici/jax/petab.py +++ b/python/sdist/amici/jax/petab.py @@ -1,7 +1,8 @@ """PEtab wrappers for JAX models.""" "" - +import shutil from numbers import Number from collections.abc import Iterable +from pathlib import Path import diffrax import equinox as eqx @@ -12,6 +13,7 @@ import pandas as pd import petab.v1 as petab +from amici import _module_from_path from amici.petab.parameter_mapping import ( ParameterMappingForCondition, create_parameter_mapping, @@ -84,6 +86,45 @@ def __init__(self, model: JAXModel, petab_problem: petab.Problem): self._measurements = self._get_measurements(scs) self.parameters = self._get_nominal_parameter_values() + def save(self, directory: Path): + """ + Save the problem to a directory. + + :param directory: + Directory to save the problem to. + """ + self._petab_problem.to_files( + prefix_path=directory, + model_file="model", + condition_file="conditions.tsv", + measurement_file="measurements.tsv", + parameter_file="parameters.tsv", + observable_file="observables.tsv", + yaml_file="problem.yaml", + ) + shutil.copy(self.model.jax_py_file, directory / "jax_py_file.py") + with open(directory / "parameters.pkl", "wb") as f: + eqx.tree_serialise_leaves(f, self) + + @classmethod + def load(cls, directory: Path): + """ + Load a problem from a directory. + + :param directory: + Directory to load the problem from. + + :return: + Loaded problem instance. + """ + petab_problem = petab.Problem.from_yaml( + directory / "problem.yaml", + ) + model = _module_from_path("jax", directory / "jax_py_file.py").Model() + problem = cls(model, petab_problem) + with open(directory / "parameters.pkl", "rb") as f: + return eqx.tree_deserialise_leaves(f, problem) + def _get_parameter_mappings( self, simulation_conditions: pd.DataFrame ) -> dict[str, ParameterMappingForCondition]: diff --git a/python/tests/test_jax.py b/python/tests/test_jax.py index 3254667c50..30e205ca26 100644 --- a/python/tests/test_jax.py +++ b/python/tests/test_jax.py @@ -1,10 +1,12 @@ import pytest import amici +from pathlib import Path pytest.importorskip("jax") import amici.jax import jax.numpy as jnp +import jax.random as jr import jax import diffrax import numpy as np @@ -12,7 +14,10 @@ from amici.pysb_import import pysb2amici from amici.testing import TemporaryDirectoryWinSafe, skip_on_valgrind +from amici.petab.petab_import import import_petab_problem +from amici.jax import JAXProblem from numpy.testing import assert_allclose +from test_petab_objective import lotka_volterra # noqa: F401 pysb = pytest.importorskip("pysb") @@ -222,3 +227,28 @@ def check_fields_jax( rtol=1e-5, err_msg=f"field {field} does not match", ) + + +@skip_on_valgrind +def test_serialisation(lotka_volterra): # noqa: F811 + petab_problem = lotka_volterra + with TemporaryDirectoryWinSafe( + prefix=petab_problem.model.model_id + ) as model_dir: + jax_model = import_petab_problem( + petab_problem, jax=True, model_output_dir=model_dir + ) + jax_problem = JAXProblem(jax_model, petab_problem) + # change parameters to random values to test serialisation + jax_problem.update_parameters( + jax_problem.parameters + + jr.normal(jr.PRNGKey(0), jax_problem.parameters.shape) + ) + + with TemporaryDirectoryWinSafe() as outdir: + outdir = Path(outdir) + jax_problem.save(outdir) + jax_problem_loaded = JAXProblem.load(outdir) + assert_allclose( + jax_problem.parameters, jax_problem_loaded.parameters + ) diff --git a/tests/benchmark-models/test_petab_benchmark.py b/tests/benchmark-models/test_petab_benchmark.py index 7a0afc6832..2c56089409 100644 --- a/tests/benchmark-models/test_petab_benchmark.py +++ b/tests/benchmark-models/test_petab_benchmark.py @@ -338,12 +338,6 @@ def test_jax_llh(benchmark_problem): jax=True, ) jax_problem = JAXProblem(jax_model, petab_problem) - simulation_conditions = ( - petab_problem.get_simulation_conditions_from_measurement_df() - ) - simulation_conditions = tuple( - tuple(row) for _, row in simulation_conditions.iterrows() - ) if problem_parameters: jax_problem = eqx.tree_at( lambda x: x.parameters, @@ -355,11 +349,9 @@ def test_jax_llh(benchmark_problem): if problem_id in problems_for_gradient_check_jax: (llh_jax, _), sllh_jax = eqx.filter_jit( eqx.filter_value_and_grad(run_simulations, has_aux=True) - )(jax_problem, simulation_conditions) + )(jax_problem) else: - llh_jax, _ = beartype(eqx.filter_jit(run_simulations))( - jax_problem, simulation_conditions - ) + llh_jax, _ = beartype(eqx.filter_jit(run_simulations))(jax_problem) np.testing.assert_allclose( llh_jax,