Skip to content

Commit

Permalink
add jax serialisation (#2608)
Browse files Browse the repository at this point in the history
* add jax serialisation

* doc

* bad ruff

* bad ruff
  • Loading branch information
FFroehlich authored Dec 2, 2024
1 parent b9a3f1a commit 1505d90
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 20 deletions.
14 changes: 6 additions & 8 deletions python/sdist/amici/jax.template.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# ruff: noqa: F401, F821, F841
import jax.numpy as jnp
from interpax import interp1d
from pathlib import Path

from amici.jax.model import JAXModel

Expand All @@ -8,10 +10,10 @@ class JAXModel_TPL_MODEL_NAME(JAXModel):
api_version = TPL_MODEL_API_VERSION

def __init__(self):
self.jax_py_file = Path(__file__).resolve()
super().__init__()

def _xdot(self, t, x, args):

pk, tcl = args

TPL_X_SYMS = x
Expand All @@ -24,7 +26,6 @@ def _xdot(self, t, x, args):
return TPL_XDOT_RET

def _w(self, t, x, pk, tcl):

TPL_X_SYMS = x
TPL_PK_SYMS = pk
TPL_TCL_SYMS = tcl
Expand All @@ -34,23 +35,20 @@ def _w(self, t, x, pk, tcl):
return TPL_W_RET

def _x0(self, pk):

TPL_PK_SYMS = pk

TPL_X0_EQ

return TPL_X0_RET

def _x_solver(self, x):

TPL_X_RDATA_SYMS = x

TPL_X_SOLVER_EQ

return TPL_X_SOLVER_RET

def _x_rdata(self, x, tcl):

TPL_X_SYMS = x
TPL_TCL_SYMS = tcl

Expand All @@ -59,7 +57,6 @@ def _x_rdata(self, x, tcl):
return TPL_X_RDATA_RET

def _tcl(self, x, pk):

TPL_X_RDATA_SYMS = x
TPL_PK_SYMS = pk

Expand All @@ -68,7 +65,6 @@ def _tcl(self, x, pk):
return TPL_TOTAL_CL_RET

def _y(self, t, x, pk, tcl):

TPL_X_SYMS = x
TPL_PK_SYMS = pk
TPL_W_SYMS = self._w(t, x, pk, tcl)
Expand All @@ -86,7 +82,6 @@ def _sigmay(self, y, pk):

return TPL_SIGMAY_RET


def _nllh(self, t, x, pk, tcl, my, iy):
y = self._y(t, x, pk, tcl)
TPL_Y_SYMS = y
Expand All @@ -107,3 +102,6 @@ def state_ids(self):
@property
def parameter_ids(self):
return TPL_PK_IDS


Model = JAXModel_TPL_MODEL_NAME
4 changes: 3 additions & 1 deletion python/sdist/amici/jax/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# ruff: noqa: F821 F722

from abc import abstractmethod
from pathlib import Path

import diffrax
import equinox as eqx
Expand All @@ -18,8 +19,9 @@ class JAXModel(eqx.Module):
classes inheriting from JAXModel.
"""

MODEL_API_VERSION = "0.0.1"
MODEL_API_VERSION = "0.0.2"
api_version: str
jax_py_file: Path

def __init__(self):
if self.api_version != self.MODEL_API_VERSION:
Expand Down
43 changes: 42 additions & 1 deletion python/sdist/amici/jax/petab.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""PEtab wrappers for JAX models.""" ""

import shutil
from numbers import Number
from collections.abc import Iterable
from pathlib import Path

import diffrax
import equinox as eqx
Expand All @@ -12,6 +13,7 @@
import pandas as pd
import petab.v1 as petab

from amici import _module_from_path
from amici.petab.parameter_mapping import (
ParameterMappingForCondition,
create_parameter_mapping,
Expand Down Expand Up @@ -84,6 +86,45 @@ def __init__(self, model: JAXModel, petab_problem: petab.Problem):
self._measurements = self._get_measurements(scs)
self.parameters = self._get_nominal_parameter_values()

def save(self, directory: Path):
"""
Save the problem to a directory.
:param directory:
Directory to save the problem to.
"""
self._petab_problem.to_files(
prefix_path=directory,
model_file="model",
condition_file="conditions.tsv",
measurement_file="measurements.tsv",
parameter_file="parameters.tsv",
observable_file="observables.tsv",
yaml_file="problem.yaml",
)
shutil.copy(self.model.jax_py_file, directory / "jax_py_file.py")
with open(directory / "parameters.pkl", "wb") as f:
eqx.tree_serialise_leaves(f, self)

@classmethod
def load(cls, directory: Path):
"""
Load a problem from a directory.
:param directory:
Directory to load the problem from.
:return:
Loaded problem instance.
"""
petab_problem = petab.Problem.from_yaml(
directory / "problem.yaml",
)
model = _module_from_path("jax", directory / "jax_py_file.py").Model()
problem = cls(model, petab_problem)
with open(directory / "parameters.pkl", "rb") as f:
return eqx.tree_deserialise_leaves(f, problem)

def _get_parameter_mappings(
self, simulation_conditions: pd.DataFrame
) -> dict[str, ParameterMappingForCondition]:
Expand Down
30 changes: 30 additions & 0 deletions python/tests/test_jax.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,23 @@
import pytest
import amici
from pathlib import Path

pytest.importorskip("jax")
import amici.jax

import jax.numpy as jnp
import jax.random as jr
import jax
import diffrax
import numpy as np
from beartype import beartype

from amici.pysb_import import pysb2amici
from amici.testing import TemporaryDirectoryWinSafe, skip_on_valgrind
from amici.petab.petab_import import import_petab_problem
from amici.jax import JAXProblem
from numpy.testing import assert_allclose
from test_petab_objective import lotka_volterra # noqa: F401

pysb = pytest.importorskip("pysb")

Expand Down Expand Up @@ -222,3 +227,28 @@ def check_fields_jax(
rtol=1e-5,
err_msg=f"field {field} does not match",
)


@skip_on_valgrind
def test_serialisation(lotka_volterra): # noqa: F811
petab_problem = lotka_volterra
with TemporaryDirectoryWinSafe(
prefix=petab_problem.model.model_id
) as model_dir:
jax_model = import_petab_problem(
petab_problem, jax=True, model_output_dir=model_dir
)
jax_problem = JAXProblem(jax_model, petab_problem)
# change parameters to random values to test serialisation
jax_problem.update_parameters(
jax_problem.parameters
+ jr.normal(jr.PRNGKey(0), jax_problem.parameters.shape)
)

with TemporaryDirectoryWinSafe() as outdir:
outdir = Path(outdir)
jax_problem.save(outdir)
jax_problem_loaded = JAXProblem.load(outdir)
assert_allclose(
jax_problem.parameters, jax_problem_loaded.parameters
)
12 changes: 2 additions & 10 deletions tests/benchmark-models/test_petab_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,12 +338,6 @@ def test_jax_llh(benchmark_problem):
jax=True,
)
jax_problem = JAXProblem(jax_model, petab_problem)
simulation_conditions = (
petab_problem.get_simulation_conditions_from_measurement_df()
)
simulation_conditions = tuple(
tuple(row) for _, row in simulation_conditions.iterrows()
)
if problem_parameters:
jax_problem = eqx.tree_at(
lambda x: x.parameters,
Expand All @@ -355,11 +349,9 @@ def test_jax_llh(benchmark_problem):
if problem_id in problems_for_gradient_check_jax:
(llh_jax, _), sllh_jax = eqx.filter_jit(
eqx.filter_value_and_grad(run_simulations, has_aux=True)
)(jax_problem, simulation_conditions)
)(jax_problem)
else:
llh_jax, _ = beartype(eqx.filter_jit(run_simulations))(
jax_problem, simulation_conditions
)
llh_jax, _ = beartype(eqx.filter_jit(run_simulations))(jax_problem)

np.testing.assert_allclose(
llh_jax,
Expand Down

0 comments on commit 1505d90

Please sign in to comment.