Skip to content

Commit

Permalink
refactor & simplify
Browse files Browse the repository at this point in the history
  • Loading branch information
FFroehlich committed Dec 7, 2024
1 parent 4eb18d9 commit 61ab683
Show file tree
Hide file tree
Showing 3 changed files with 178 additions and 139 deletions.
98 changes: 60 additions & 38 deletions python/sdist/amici/jax/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from abc import abstractmethod
from pathlib import Path
import enum

import diffrax
import equinox as eqx
Expand All @@ -12,6 +13,20 @@
import jaxtyping as jt


class ReturnValue(enum.Enum):
llh = "log-likelihood"
nllhs = "pointwise negative log-likelihood"
x0 = "full initial state vector"
x0_solver = "reduced initial state vector"
x = "full state vector"
x_solver = "reduced state vector"
y = "observables"
sigmay = "standard deviations of the observables"
tcl = "total values for conservation laws"
res = "residuals"
chi2 = "sum(((observed - simulated) / sigma ) ** 2)"


class JAXModel(eqx.Module):
"""
JAXModel provides an abstract base class for a JAX-based implementation of an AMICI model. The class implements
Expand Down Expand Up @@ -440,7 +455,7 @@ def simulate_condition(
x_preeq: jt.Float[jt.Array, "*nx"] = jnp.array([]),
mask_reinit: jt.Bool[jt.Array, "*nx"] = jnp.array([]),
x_reinit: jt.Float[jt.Array, "*nx"] = jnp.array([]),
ret: str = "llh",
ret: ReturnValue = ReturnValue.llh,
) -> tuple[jt.Float[jt.Array, "nt *nx"] | jnp.float_, dict]:
r"""
Simulate a condition.
Expand Down Expand Up @@ -478,18 +493,7 @@ def simulate_condition(
:param max_steps:
maximum number of solver steps
:param ret:
which output to return. Valid values are
- `llh`: log-likelihood (default)
- `nllhs`: negative log-likelihood at each time point
- `x0`: full initial state vector (after pre-equilibration)
- `x0_solver`: reduced initial state vector (after pre-equilibration)
- `x`: full state vector
- `x_solver`: reduced state vector
- `y`: observables
- `sigmay`: standard deviations of the observables
- `tcl`: total values for conservation laws (at final timepoint)
- `res`: residuals (observed - simulated)
- 'chi2': sum((observed - simulated) ** 2 / sigma ** 2)
which output to return. See :class:`ReturnValue` for available options.
:return:
output according to `ret` and statistics
"""
Expand Down Expand Up @@ -542,36 +546,54 @@ def simulate_condition(

nllhs = self._nllhs(ts, x, p, tcl, my, iys)
llh = -jnp.sum(nllhs)
obs_trafo = jax.vmap(
lambda y, iy_trafo: jnp.array(
[y, safe_log(y), safe_log(y) / jnp.log(10)]
)
.at[iy_trafo]
.get(),
)
ys_obj = obs_trafo(self._ys(ts, x, p, tcl, iys), iy_trafos)
m_obj = obs_trafo(my, iy_trafos)
return {
"llh": llh,
"nllhs": nllhs,
"x": self._x_rdatas(x, tcl),
"x_solver": x,
"y": self._ys(ts, x, p, tcl, iys),
"sigmay": self._sigmays(ts, x, p, tcl, iys),
"x0": self._x_rdata(x[0, :], tcl),
"x0_solver": x[0, :],
"tcl": tcl,
"res": self._ys(ts, x, p, tcl, iys) - my,
"chi2": jnp.sum(
jnp.square(ys_obj - m_obj)
/ jnp.square(self._sigmays(ts, x, p, tcl, iys))
),
}[ret], dict(

stats = dict(
ts=ts,
x=x,
llh=llh,
stats_dyn=stats_dyn,
stats_posteq=stats_posteq,
)
if ret == ReturnValue.llh:
output = llh
elif ret == ReturnValue.nllhs:
output = nllhs

Check warning on line 560 in python/sdist/amici/jax/model.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/model.py#L560

Added line #L560 was not covered by tests
elif ret == ReturnValue.x:
output = self._x_rdatas(x, tcl)

Check warning on line 562 in python/sdist/amici/jax/model.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/model.py#L562

Added line #L562 was not covered by tests
elif ret == ReturnValue.x_solver:
output = x

Check warning on line 564 in python/sdist/amici/jax/model.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/model.py#L564

Added line #L564 was not covered by tests
elif ret == ReturnValue.y:
output = self._ys(ts, x, p, tcl, iys)

Check warning on line 566 in python/sdist/amici/jax/model.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/model.py#L566

Added line #L566 was not covered by tests
elif ret == ReturnValue.sigmay:
output = self._sigmays(ts, x, p, tcl, iys)

Check warning on line 568 in python/sdist/amici/jax/model.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/model.py#L568

Added line #L568 was not covered by tests
elif ret == ReturnValue.x0:
output = self._x_rdata(x[0, :], tcl)

Check warning on line 570 in python/sdist/amici/jax/model.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/model.py#L570

Added line #L570 was not covered by tests
elif ret == ReturnValue.x0_solver:
output = x[0, :]

Check warning on line 572 in python/sdist/amici/jax/model.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/model.py#L572

Added line #L572 was not covered by tests
elif ret == ReturnValue.tcl:
output = tcl

Check warning on line 574 in python/sdist/amici/jax/model.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/model.py#L574

Added line #L574 was not covered by tests
elif ret in (ReturnValue.res, ReturnValue.chi2):
obs_trafo = jax.vmap(

Check warning on line 576 in python/sdist/amici/jax/model.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/model.py#L576

Added line #L576 was not covered by tests
lambda y, iy_trafo: jnp.array(
# needs to follow order in amici.jax.petab.SCALE_TO_INT
[y, safe_log(y), safe_log(y) / jnp.log(10)]
)
.at[iy_trafo]
.get(),
)
ys_obj = obs_trafo(self._ys(ts, x, p, tcl, iys), iy_trafos)
m_obj = obs_trafo(my, iy_trafos)
if ret == ReturnValue.chi2:
output = jnp.sum(

Check warning on line 587 in python/sdist/amici/jax/model.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/model.py#L584-L587

Added lines #L584 - L587 were not covered by tests
jnp.square(ys_obj - m_obj)
/ jnp.square(self._sigmays(ts, x, p, tcl, iys))
)
else:
output = ys_obj - m_obj

Check warning on line 592 in python/sdist/amici/jax/model.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/model.py#L592

Added line #L592 was not covered by tests
else:
raise NotImplementedError(f"Return value {ret} not implemented.")

return output, stats

@eqx.filter_jit
def preequilibrate_condition(
Expand Down
Loading

0 comments on commit 61ab683

Please sign in to comment.