Skip to content

Commit

Permalink
review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
FFroehlich committed Nov 18, 2024
1 parent 186805c commit 250f9dd
Show file tree
Hide file tree
Showing 7 changed files with 55 additions and 41 deletions.
1 change: 1 addition & 0 deletions documentation/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ def install_doxygen():
"numpy": ("https://numpy.org/devdocs/", None),
"sympy": ("https://docs.sympy.org/latest/", None),
"python": ("https://docs.python.org/3", None),
"jax": ["https://jax.readthedocs.io/en/latest/", None],
}

# Add notebooks prolog with binder links
Expand Down
1 change: 1 addition & 0 deletions documentation/python_modules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ AMICI Python API
amici.petab_objective
amici.petab_simulate
amici.import_utils
amici.jax
amici.de_export
amici.de_model
amici.de_model_components
Expand Down
49 changes: 30 additions & 19 deletions python/sdist/amici/de_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
TYPE_CHECKING,
Literal,
)
from itertools import chain

import sympy as sp

Expand Down Expand Up @@ -300,30 +301,38 @@ def jnp_array_str(array) -> str:

return f"jnp.array([{elems}])"

# replaces Heaviside variables with corresponding functions
subs_heaviside = dict(
zip(
self.model.sym("h"),
[sp.Heaviside(x) for x in self.model.eq("root")],
strict=True,
)
)
# replaces observables with a generic my variable
subs_observables = dict(
zip(
self.model.sym("my"),
[sp.Symbol("my")] * len(self.model.sym("my")),
strict=True,
)
)

tpl_data = {
# assign named variable using corresponding algebraic formula (function body)
**{
f"{eq_name.upper()}_EQ": "\n".join(
self._code_printer_jax._get_sym_lines(
(str(strip_pysb(s)) for s in self.model.sym(eq_name)),
self.model.eq(eq_name).subs(
dict(
zip(
list(self.model.sym("h"))
+ list(self.model.sym("my")),
[
sp.Heaviside(x)
for x in self.model.eq("root")
]
+ [sp.Symbol("my")]
* len(self.model.sym("my")),
)
)
{**subs_heaviside, **subs_observables}
),
indent,
)
)[indent:]
)[indent:] # remove indent for first line
for eq_name in eq_names
},
# create jax array from concatenation of named variables
**{
f"{eq_name.upper()}_RET": jnp_array_str(
strip_pysb(s) for s in self.model.sym(eq_name)
Expand All @@ -332,6 +341,7 @@ def jnp_array_str(array) -> str:
else "jnp.array([])"
for eq_name in eq_names
},
# assign named variables from a jax array
**{
f"{sym_name.upper()}_SYMS": "".join(
str(strip_pysb(s)) + ", " for s in self.model.sym(sym_name)
Expand All @@ -340,6 +350,7 @@ def jnp_array_str(array) -> str:
else "_"
for sym_name in sym_names
},
# tuple of variable names (ids as they are unique)
**{
f"{sym_name.upper()}_IDS": "".join(
f'"{strip_pysb(s)}", ' for s in self.model.sym(sym_name)
Expand All @@ -349,19 +360,19 @@ def jnp_array_str(array) -> str:
for sym_name in ("p", "k", "y", "x")
},
**{
# in jax model we do not need to distinguish between p (parameters) and
# k (fixed parameters) so we use a single variable combining both
"PK_SYMS": "".join(
str(strip_pysb(s)) + ", "
for s in list(self.model.sym("p"))
+ list(self.model.sym("k"))
for s in chain(self.model.sym("p"), self.model.sym("k"))
),
"PK_IDS": "".join(
f'"{strip_pysb(s)}", '
for s in list(self.model.sym("p"))
+ list(self.model.sym("k"))
for s in chain(self.model.sym("p"), self.model.sym("k"))
),
},
**{
"MODEL_NAME": self.model_name,
# keep track of the API version that the model was generated with so we
# can flag conflicts in the future
"MODEL_API_VERSION": f"'{JAXModel.MODEL_API_VERSION}'",
},
}
Expand Down
1 change: 1 addition & 0 deletions python/sdist/amici/jax/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Interface to facilitate AMICI generated models using JAX"""
12 changes: 7 additions & 5 deletions python/sdist/amici/jax/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ def _llh(
) -> jt.Float[jt.Scalar, ""]:
"""
Compute the log-likelihood of the observable for the specified observable index.
:param t:
time point
:param x:
Expand Down Expand Up @@ -430,10 +431,11 @@ def simulate_condition(
controller: diffrax.AbstractStepSizeController,
adjoint: diffrax.AbstractAdjoint,
max_steps: int | jnp.int_,
ret: str = "llh",
ret: str = "nllh",
):
r"""
Simulate a condition.
:param p:
parameters for simulation ordered according to ids in :ivar parameter_ids:
:param p_preeq:
Expand Down Expand Up @@ -464,8 +466,8 @@ def simulate_condition(
maximum number of solver steps
:param ret:
which output to return. Valid values are
- `llh`: negative log-likelihood (default)
- `llhs`: negative log-likelihoods at each time point
- `nllh`: negative log-likelihood (default)
- `llhs`: log-likelihoods at each time point
- `x0`: full initial state vector (after pre-equilibration)
- `x0_solver`: reduced initial state vector (after pre-equilibration)
- `x`: full state vector
Expand Down Expand Up @@ -532,9 +534,9 @@ def simulate_condition(
x = jnp.concatenate((x_preq, x_dyn, x_posteq), axis=0)

llhs = self._llhs(ts, x, p, tcl, my, iys)
llh = -jnp.sum(llhs)
nllh = -jnp.sum(llhs)
return {
"llh": llh,
"nllh": nllh,
"llhs": llhs,
"x": self._x_rdatas(x, tcl),
"x_solver": x,
Expand Down
16 changes: 6 additions & 10 deletions python/sdist/amici/jax/petab.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def _get_parameter_mappings(
for sim_var, value in mapping.map_sim_var.items():
if isinstance(value, Number) and not np.isfinite(value):
mapping.map_sim_var[sim_var] = 1.0
return dict(zip(scs, mappings))
return dict(zip(scs, mappings, strict=True))

def _get_measurements(
self, simulation_conditions: pd.DataFrame
Expand All @@ -117,7 +117,7 @@ def _get_measurements(
tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray],
]:
"""
Set measurements for the model based on the provided simulation conditions.
Get measurements for the model based on the provided simulation conditions.
:param simulation_conditions:
Simulation conditions to create parameter mappings for. Same format as returned by
Expand Down Expand Up @@ -156,17 +156,13 @@ def _get_measurements(
)
return measurements

def _get_nominal_parameter_values(self) -> jnp.ndarray:
def _get_nominal_parameter_values(self) -> jt.Float[jt.Array, "np"]:
"""
Set the nominal parameter values for the model based on the nominal values in the PEtab problem.
Get the nominal parameter values for the model based on the nominal values in the PEtab problem.
:return:
JAXModel instance with parameter values set to the nominal values.
jax array with nominal parameter values
"""
if self._petab_problem is None:
raise ValueError(
"PEtab problem not set, cannot set nominal values."
)
return jnp.array(
[
petab.scale(
Expand Down Expand Up @@ -306,7 +302,7 @@ def run_simulations(
icoeff=0.3,
dcoeff=0.0,
),
max_steps: int = 2**14,
max_steps: int = 2**10,
):
"""
Run simulations for a problem.
Expand Down
16 changes: 9 additions & 7 deletions tests/benchmark-models/test_petab_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,10 @@
from pathlib import Path
import fiddy
import amici
import equinox as eqx
import jax.numpy as jnp
import numpy as np
import pandas as pd
import petab.v1 as petab
import pytest
import jax
from amici.petab.petab_import import import_petab_problem
import benchmark_models_petab
from collections import defaultdict
Expand All @@ -37,11 +34,8 @@
rdatas_to_measurement_df,
simulate_petab,
)
from amici.jax.petab import run_simulations, JAXProblem
from petab.v1.visualize import plot_problem
from beartype import beartype

jax.config.update("jax_enable_x64", True)
from petab.v1.visualize import plot_problem


# Enable various debug output
Expand Down Expand Up @@ -267,6 +261,14 @@ def benchmark_problem(request):
"ignore:Adjoint sensitivity analysis for models with discontinuous ",
)
def test_jax_llh(benchmark_problem):
import jax
import equinox as eqx
import jax.numpy as jnp
from amici.jax.petab import run_simulations, JAXProblem

jax.config.update("jax_enable_x64", True)
from beartype import beartype

problem_id, petab_problem, amici_model = benchmark_problem

if problem_id in (
Expand Down

0 comments on commit 250f9dd

Please sign in to comment.