Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Separate pre-equilibration and dynamic simulation in jax #2617

Merged
merged 15 commits into from
Dec 5, 2024
11 changes: 4 additions & 7 deletions python/examples/example_jax_petab/ExampleJaxPEtab.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@
" results (dict): Simulation results from run_simulations.\n",
" \"\"\"\n",
" # Extract the simulation results for the specific condition\n",
" sim_results = results[simulation_condition][1]\n",
" sim_results = results[simulation_condition]\n",
"\n",
" # Create a new figure for the state trajectories\n",
" plt.figure(figsize=(8, 6))\n",
Expand Down Expand Up @@ -357,27 +357,25 @@
"simulation_condition = (\"model1_data1\",)\n",
"\n",
"# Load condition-specific data\n",
"ts_preeq, ts_dyn, ts_posteq, my, iys = jax_problem._measurements[\n",
"ts_init, ts_dyn, ts_posteq, my, iys = jax_problem._measurements[\n",
" simulation_condition\n",
"]\n",
"\n",
"# Load parameters for the specified condition\n",
"p = jax_problem.load_parameters(simulation_condition[0])\n",
"# Disable preequilibration\n",
"p_preeq = jnp.array([])\n",
"\n",
"\n",
"# Define a function to compute the gradient with respect to dynamic timepoints\n",
"@eqx.filter_jacfwd\n",
"def grad_ts_dyn(tt):\n",
" return jax_problem.model.simulate_condition(\n",
" p=p,\n",
" p_preeq=p_preeq,\n",
" ts_preeq=ts_preeq,\n",
" ts_init=ts_init,\n",
" ts_dyn=tt,\n",
" ts_posteq=ts_posteq,\n",
" my=jnp.array(my),\n",
" iys=jnp.array(iys),\n",
" x_preeq=jnp.array([]),\n",
" solver=diffrax.Kvaerno5(),\n",
" controller=diffrax.PIDController(atol=1e-8, rtol=1e-8),\n",
" max_steps=2**10,\n",
Expand Down Expand Up @@ -489,7 +487,6 @@
"amici_model = import_petab_problem(\n",
" petab_problem,\n",
" verbose=False,\n",
" compile_=True,\n",
" jax=False, # load the amici model this time\n",
")\n",
"\n",
Expand Down
64 changes: 42 additions & 22 deletions python/sdist/amici/jax/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,12 +427,12 @@
def simulate_condition(
self,
p: jt.Float[jt.Array, "np"],
p_preeq: jt.Float[jt.Array, "*np"],
ts_preeq: jt.Float[jt.Array, "nt_preeq"],
ts_init: jt.Float[jt.Array, "nt_preeq"],
ts_dyn: jt.Float[jt.Array, "nt_dyn"],
ts_posteq: jt.Float[jt.Array, "nt_posteq"],
my: jt.Float[jt.Array, "nt"],
iys: jt.Int[jt.Array, "nt"],
x_preeq: jt.Float[jt.Array, "nx"],
solver: diffrax.AbstractSolver,
controller: diffrax.AbstractStepSizeController,
adjoint: diffrax.AbstractAdjoint,
Expand All @@ -444,12 +444,9 @@

:param p:
parameters for simulation ordered according to ids in :ivar parameter_ids:
:param p_preeq:
parameters for pre-equilibration ordered according to ids in :ivar parameter_ids:. May be empty to
disable pre-equilibration.
:param ts_preeq:
time points for pre-equilibration. Usually valued 0.0, but needs to be shaped according to
the number of observables that are evaluated after pre-equilibration.
:param ts_init:
time points that do not require simulation. Usually valued 0.0, but needs to be shaped according to
the number of observables that are evaluated before dynamic simulation.
:param ts_dyn:
time points for dynamic simulation. Usually valued > 0.0 and sorted in monotonically increasing order.
Duplicate time points are allowed to facilitate the evaluation of multiple observables at specific time
Expand Down Expand Up @@ -486,24 +483,16 @@
output according to `ret` and statistics
"""
# Pre-equilibration
if p_preeq.shape[0] > 0:
x0 = self._x0(p_preeq)
tcl = self._tcl(x0, p_preeq)
current_x = self._x_solver(x0)
current_x, stats_preeq = self._eq(
p_preeq, tcl, current_x, solver, controller, max_steps
)
if x_preeq.shape[0] > 0:
current_x = self._x_solver(x_preeq)

Check warning on line 487 in python/sdist/amici/jax/model.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/model.py#L487

Added line #L487 was not covered by tests
# update tcl with new parameters
tcl = self._tcl(self._x_rdata(current_x, tcl), p)
tcl = self._tcl(x_preeq, p)

Check warning on line 489 in python/sdist/amici/jax/model.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/model.py#L489

Added line #L489 was not covered by tests
else:
x0 = self._x0(p)
current_x = self._x_solver(x0)
stats_preeq = None

tcl = self._tcl(x0, p)
x_preq = jnp.repeat(
current_x.reshape(1, -1), ts_preeq.shape[0], axis=0
)
x_preq = jnp.repeat(current_x.reshape(1, -1), ts_init.shape[0], axis=0)

# Dynamic simulation
if ts_dyn.shape[0] > 0:
Expand Down Expand Up @@ -536,7 +525,7 @@
current_x.reshape(1, -1), ts_posteq.shape[0], axis=0
)

ts = jnp.concatenate((ts_preeq, ts_dyn, ts_posteq), axis=0)
ts = jnp.concatenate((ts_init, ts_dyn, ts_posteq), axis=0)
x = jnp.concatenate((x_preq, x_dyn, x_posteq), axis=0)

nllhs = self._nllhs(ts, x, p, tcl, my, iys)
Expand All @@ -555,11 +544,42 @@
}[ret], dict(
ts=ts,
x=x,
stats_preeq=stats_preeq,
stats_dyn=stats_dyn,
stats_posteq=stats_posteq,
)

@eqx.filter_jit
def preequilibrate_condition(
self,
p: jt.Float[jt.Array, "np"],
solver: diffrax.AbstractSolver,
controller: diffrax.AbstractStepSizeController,
max_steps: int | jnp.int_,
) -> tuple[jt.Float[jt.Array, "nx"], dict]:
r"""
Simulate a condition.

:param p:
parameters for simulation ordered according to ids in :ivar parameter_ids:
:param solver:
ODE solver
:param controller:
step size controller
:param max_steps:
maximum number of solver steps
:return:
pre-equilibrated state variables and statistics
"""
# Pre-equilibration
x0 = self._x0(p)
tcl = self._tcl(x0, p)
current_x = self._x_solver(x0)
current_x, stats_preeq = self._eq(

Check warning on line 577 in python/sdist/amici/jax/model.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/model.py#L574-L577

Added lines #L574 - L577 were not covered by tests
p, tcl, current_x, solver, controller, max_steps
)

return self._x_rdata(current_x, tcl), dict(stats_preeq=stats_preeq)

Check warning on line 581 in python/sdist/amici/jax/model.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/model.py#L581

Added line #L581 was not covered by tests


def safe_log(x: jnp.float_) -> jnp.float_:
"""
Expand Down
5 changes: 2 additions & 3 deletions python/sdist/amici/jax/ode_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,12 +234,10 @@ def _generate_jax_code(self) -> None:
"MODEL_API_VERSION": f"'{JAXModel.MODEL_API_VERSION}'",
},
}
outdir = self.model_path / (self.model_name + "_jax")
outdir.mkdir(parents=True, exist_ok=True)

apply_template(
Path(amiciModulePath) / "jax" / "jax.template.py",
outdir / "__init__.py",
self.model_path / "__init__.py",
tpl_data,
)

Expand All @@ -258,6 +256,7 @@ def set_paths(self, output_dir: str | Path | None = None) -> None:
output_dir = Path(os.getcwd()) / f"amici-{self.model_name}"

self.model_path = Path(output_dir).resolve()
self.model_path.mkdir(parents=True, exist_ok=True)

def set_name(self, model_name: str) -> None:
"""
Expand Down
67 changes: 54 additions & 13 deletions python/sdist/amici/jax/petab.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@
def _get_measurements(
self, simulation_conditions: pd.DataFrame
) -> dict[
tuple[str],
tuple[str, ...],
tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray],
]:
"""
Expand Down Expand Up @@ -307,49 +307,75 @@
solver: diffrax.AbstractSolver,
controller: diffrax.AbstractStepSizeController,
max_steps: jnp.int_,
x_preeq: jt.Float[jt.Array, "*nx"] = jnp.array([]), # noqa: F821, F722
) -> tuple[jnp.float_, dict]:
"""
Run a simulation for a given simulation condition.

:param simulation_condition:
Tuple of simulation conditions to run the simulation for. can be a single string (simulation only) or a
tuple of strings (pre-equilibration followed by simulation).
Simulation condition to run simulation for.
:param solver:
ODE solver to use for simulation
:param controller:
Step size controller to use for simulation
:param max_steps:
Maximum number of steps to take during simulation
:param x_preeq:
Pre-equilibration state if available
:return:
Tuple of log-likelihood and simulation statistics
"""
ts_preeq, ts_dyn, ts_posteq, my, iys = self._measurements[
simulation_condition
]
p = self.load_parameters(simulation_condition[0])
p_preeq = (
self.load_parameters(simulation_condition[1])
if len(simulation_condition) > 1
else jnp.array([])
)
return self.model.simulate_condition(
p=p,
p_preeq=p_preeq,
ts_preeq=jax.lax.stop_gradient(jnp.array(ts_preeq)),
ts_init=jax.lax.stop_gradient(jnp.array(ts_preeq)),
ts_dyn=jax.lax.stop_gradient(jnp.array(ts_dyn)),
ts_posteq=jax.lax.stop_gradient(jnp.array(ts_posteq)),
my=jax.lax.stop_gradient(jnp.array(my)),
iys=jax.lax.stop_gradient(jnp.array(iys)),
x_preeq=x_preeq,
solver=solver,
controller=controller,
max_steps=max_steps,
adjoint=diffrax.RecursiveCheckpointAdjoint(),
)

def run_preequilibration(
self,
simulation_condition: str,
solver: diffrax.AbstractSolver,
controller: diffrax.AbstractStepSizeController,
max_steps: jnp.int_,
) -> tuple[jt.Float[jt.Array, "nx"], dict]: # noqa: F821
"""
Run a pre-equilibration simulation for a given simulation condition.

:param simulation_condition:
Simulation condition to run simulation for.
:param solver:
ODE solver to use for simulation
:param controller:
Step size controller to use for simulation
:param max_steps:
Maximum number of steps to take during simulation
:return:
Pre-equilibration state
"""
p = self.load_parameters(simulation_condition)
return self.model.preequilibrate_condition(

Check warning on line 368 in python/sdist/amici/jax/petab.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/petab.py#L367-L368

Added lines #L367 - L368 were not covered by tests
p=p,
solver=solver,
controller=controller,
max_steps=max_steps,
)


def run_simulations(
problem: JAXProblem,
simulation_conditions: Iterable[tuple] | None = None,
simulation_conditions: Iterable[tuple[str, ...]] | None = None,
solver: diffrax.AbstractSolver = diffrax.Kvaerno5(),
controller: diffrax.AbstractStepSizeController = diffrax.PIDController(
rtol=1e-8,
Expand Down Expand Up @@ -379,8 +405,23 @@
if simulation_conditions is None:
simulation_conditions = problem.get_all_simulation_conditions()

preeqs = {

Check warning on line 408 in python/sdist/amici/jax/petab.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/petab.py#L408

Added line #L408 was not covered by tests
sc: problem.run_preequilibration(sc, solver, controller, max_steps)
# only run preequilibration once per condition
for sc in {sc[1] for sc in simulation_conditions if len(sc) > 1}
}

results = {
sc: problem.run_simulation(sc, solver, controller, max_steps)
sc: problem.run_simulation(
sc,
solver,
controller,
max_steps,
preeqs.get(sc[1])[0] if len(sc) > 1 else jnp.array([]),
)
for sc in simulation_conditions
}
return sum(llh for llh, _ in results.values()), results
return sum(llh for llh, _ in results.values()), {

Check warning on line 424 in python/sdist/amici/jax/petab.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/petab.py#L424

Added line #L424 was not covered by tests
sc: res[1] | preeqs[sc[1]][1] if len(sc) > 1 else res[1]
for sc, res in results.items()
}
24 changes: 22 additions & 2 deletions python/sdist/amici/petab/import_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,10 +138,9 @@ def _can_import_model(
Check whether a module of that name can already be imported.
"""
# try to import (in particular checks version)
suffix = "_jax" if jax else ""
try:
model_module = amici.import_model_module(
model_name + suffix, model_output_dir
*_get_package_name_and_path(model_name, model_output_dir, jax)
)
except ModuleNotFoundError:
return False
Expand Down Expand Up @@ -271,3 +270,24 @@ def check_model(
"the current model might also resolve this. Parameters: "
f"{amici_ids_free_required.difference(amici_ids_free)}"
)


def _get_package_name_and_path(
model_name: str, model_output_dir: str | Path, jax: bool = False
) -> tuple[str, Path]:
"""
Get the package name and path for the generated model module.

:param model_name:
Name of the model
:param model_output_dir:
Target directory for the generated model module
:param jax:
Whether to generate the paths for a JAX or CPP model
:return:
"""
if jax:
outdir = Path(model_output_dir)
return outdir.stem, outdir.parent
else:
return model_name, Path(model_output_dir)
14 changes: 9 additions & 5 deletions python/sdist/amici/petab/petab_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,12 @@
from petab.v1.models import MODEL_TYPE_PYSB, MODEL_TYPE_SBML

from ..logging import get_logger
from .import_helpers import _can_import_model, _create_model_name, check_model
from .import_helpers import (
_can_import_model,
_create_model_name,
check_model,
_get_package_name_and_path,
)
from .sbml_import import import_model_sbml

try:
Expand Down Expand Up @@ -114,7 +119,7 @@ def import_petab_problem(
from .sbml_import import _create_model_output_dir_name

model_output_dir = _create_model_output_dir_name(
petab_problem.sbml_model, model_name
petab_problem.sbml_model, model_name, jax=jax
)
else:
model_output_dir = os.path.abspath(model_output_dir)
Expand All @@ -136,7 +141,7 @@ def import_petab_problem(
)

# remove folder if exists
if os.path.exists(model_output_dir):
if not jax and os.path.exists(model_output_dir):
shutil.rmtree(model_output_dir)

logger.info(f"Compiling model {model_name} to {model_output_dir}.")
Expand All @@ -160,9 +165,8 @@ def import_petab_problem(
)

# import model
suffix = "_jax" if jax else ""
model_module = amici.import_model_module(
model_name + suffix, model_output_dir
*_get_package_name_and_path(model_name, model_output_dir, jax=jax)
)

if jax:
Expand Down
Loading
Loading