Skip to content

Commit

Permalink
add example
Browse files Browse the repository at this point in the history
  • Loading branch information
FFroehlich committed Nov 19, 2024
1 parent d9ae05e commit f7c2c10
Show file tree
Hide file tree
Showing 6 changed files with 1,210 additions and 18 deletions.
1 change: 1 addition & 0 deletions documentation/ExampleJaxPEtab.ipynb
1 change: 1 addition & 0 deletions documentation/python_examples.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,6 @@ Various example notebooks.
example_errors.ipynb
example_large_models/example_performance_optimization.ipynb
ExampleJax.ipynb
ExampleJaxPEtab.ipynb
ExampleSplines.ipynb
ExampleSplinesSwameye2003.ipynb
1,171 changes: 1,171 additions & 0 deletions python/examples/example_jax_petab/ExampleJaxPEtab.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion python/sdist/amici/jax.template.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def _sigmay(self, y, pk):
return TPL_SIGMAY_RET


def _llh(self, t, x, pk, tcl, my, iy):
def _nllh(self, t, x, pk, tcl, my, iy):
y = self._y(t, x, pk, tcl)
TPL_Y_SYMS = y
TPL_SIGMAY_SYMS = self._sigmay(y, pk)
Expand Down
29 changes: 15 additions & 14 deletions python/sdist/amici/jax/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def _sigmay(
...

@abstractmethod
def _llh(
def _nllh(
self,
t: jt.Float[jt.Scalar, ""],
x: jt.Float[jt.Array, "nxs"],
Expand All @@ -172,7 +172,7 @@ def _llh(
iy: jt.Int[jt.Array, ""],
) -> jt.Float[jt.Scalar, ""]:
"""
Compute the log-likelihood of the observable for the specified observable index.
Compute the negative log-likelihood of the observable for the specified observable index.
:param t:
time point
Expand Down Expand Up @@ -326,7 +326,7 @@ def _x_rdatas(
"""
return jax.vmap(self._x_rdata, in_axes=(0, None))(x, tcl)

def _llhs(
def _nllhs(
self,
ts: jt.Float[jt.Array, "nt nx"],
xs: jt.Float[jt.Array, "nt nxs"],
Expand All @@ -336,7 +336,7 @@ def _llhs(
iys: jt.Int[jt.Array, "nt"],
) -> jt.Float[jt.Array, "nt"]:
"""
Compute the log-likelihood of the observables.
Compute the negative log-likelihood for each observable.
:param ts:
time points
Expand All @@ -351,9 +351,9 @@ def _llhs(
:param iys:
observable indices
:return:
log-likelihood of the observables
negative log-likelihoods of the observables
"""
return jax.vmap(self._llh, in_axes=(0, 0, None, None, 0, 0))(
return jax.vmap(self._nllh, in_axes=(0, 0, None, None, 0, 0))(
ts, xs, p, tcl, mys, iys
)

Expand Down Expand Up @@ -431,8 +431,8 @@ def simulate_condition(
controller: diffrax.AbstractStepSizeController,
adjoint: diffrax.AbstractAdjoint,
max_steps: int | jnp.int_,
ret: str = "nllh",
):
ret: str = "llh",
) -> tuple[jt.Float[jt.Array, "nt *nx"] | jnp.float_, dict]:
r"""
Simulate a condition.
Expand Down Expand Up @@ -466,8 +466,8 @@ def simulate_condition(
maximum number of solver steps
:param ret:
which output to return. Valid values are
- `nllh`: negative log-likelihood (default)
- `llhs`: log-likelihoods at each time point
- `llh`: log-likelihood (default)
- `nllhs`: negative log-likelihood at each time point
- `x0`: full initial state vector (after pre-equilibration)
- `x0_solver`: reduced initial state vector (after pre-equilibration)
- `x`: full state vector
Expand Down Expand Up @@ -533,11 +533,11 @@ def simulate_condition(
ts = jnp.concatenate((ts_preeq, ts_dyn, ts_posteq), axis=0)
x = jnp.concatenate((x_preq, x_dyn, x_posteq), axis=0)

llhs = self._llhs(ts, x, p, tcl, my, iys)
nllh = -jnp.sum(llhs)
nllhs = self._nllhs(ts, x, p, tcl, my, iys)
llh = -jnp.sum(nllhs)
return {
"nllh": nllh,
"llhs": llhs,
"llh": llh,
"nllhs": nllhs,
"x": self._x_rdatas(x, tcl),
"x_solver": x,
"y": self._ys(ts, x, p, tcl, iys),
Expand All @@ -547,6 +547,7 @@ def simulate_condition(
"tcl": tcl,
"res": self._ys(ts, x, p, tcl, iys) - my,
}[ret], dict(
ts=ts,
x=x,
stats_preeq=stats_preeq,
stats_dyn=stats_dyn,
Expand Down
24 changes: 21 additions & 3 deletions python/sdist/amici/jax/petab.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class JAXProblem(eqx.Module):
model: JAXModel
_parameter_mappings: dict[str, ParameterMappingForCondition]
_measurements: dict[
tuple[str],
tuple[str, ...],
tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray],
]
_petab_problem: petab.Problem
Expand Down Expand Up @@ -156,6 +156,12 @@ def _get_measurements(
)
return measurements

def get_all_simulation_conditions(self) -> tuple[tuple[str, ...], ...]:
simulation_conditions = (
self._petab_problem.get_simulation_conditions_from_measurement_df()
)
return tuple(tuple(row) for _, row in simulation_conditions.iterrows())

def _get_nominal_parameter_values(self) -> jt.Float[jt.Array, "np"]:
"""
Get the nominal parameter values for the model based on the nominal values in the PEtab problem.
Expand Down Expand Up @@ -245,9 +251,18 @@ def load_parameters(
)
return self._unscale(p, pscale)

def update_parameters(self, p: jt.Float[jt.Array, "np"]) -> "JAXProblem":
"""
Update parameters for the model.
:param p:
New problem instance with updated parameters.
"""
return eqx.tree_at(lambda p: p.parameters, self, p)

def run_simulation(
self,
simulation_condition: tuple[str],
simulation_condition: tuple[str, ...],
solver: diffrax.AbstractSolver,
controller: diffrax.AbstractStepSizeController,
max_steps: jnp.int_,
Expand Down Expand Up @@ -293,7 +308,7 @@ def run_simulation(

def run_simulations(
problem: JAXProblem,
simulation_conditions: Iterable[tuple],
simulation_conditions: Iterable[tuple] | None = None,
solver: diffrax.AbstractSolver = diffrax.Kvaerno5(),
controller: diffrax.AbstractStepSizeController = diffrax.PIDController(
rtol=1e-8,
Expand All @@ -320,6 +335,9 @@ def run_simulations(
:return:
Overall negative log-likelihood and condition specific results and statistics.
"""
if simulation_conditions is None:
simulation_conditions = problem.get_all_simulation_conditions()

results = {
sc: problem.run_simulation(sc, solver, controller, max_steps)
for sc in simulation_conditions
Expand Down

0 comments on commit f7c2c10

Please sign in to comment.