Skip to content

Commit

Permalink
refactor & simplify
Browse files Browse the repository at this point in the history
  • Loading branch information
FFroehlich committed Nov 12, 2024
1 parent da02106 commit a46e65d
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 103 deletions.
207 changes: 107 additions & 100 deletions python/sdist/amici/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import jax.numpy as jnp
import numpy as np
import jax
import pandas as pd
import petab.v1 as petab

import amici
Expand All @@ -24,66 +25,34 @@


class JAXModel(eqx.Module):
_unscale_funs = {
amici.ParameterScaling.none: lambda x: x,
amici.ParameterScaling.ln: lambda x: jnp.exp(x),
amici.ParameterScaling.log10: lambda x: jnp.power(10, x),
}
solver: diffrax.AbstractSolver
controller: diffrax.AbstractStepSizeController
atol: float
rtol: float
pcoeff: float
icoeff: float
dcoeff: float
maxsteps: int
parameters: jnp.ndarray
parameter_mappings: dict[tuple[str], ParameterMappingForCondition]
term: diffrax.ODETerm
parameter_mappings: dict[tuple[str], ParameterMappingForCondition] | None
measurements: dict[tuple[str], pd.DataFrame] | None
petab_problem: petab.Problem | None

def __init__(self):
self.solver = diffrax.Kvaerno5()
self.atol: float = 1e-8
self.rtol: float = 1e-8
self.pcoeff: float = 0.4
self.icoeff: float = 0.3
self.dcoeff: float = 0.0
self.maxsteps: int = 2**14
self.controller = diffrax.PIDController(
rtol=self.rtol,
atol=self.atol,
pcoeff=self.pcoeff,
icoeff=self.icoeff,
dcoeff=self.dcoeff,
rtol=1e-8,
atol=1e-8,
pcoeff=0.4,
icoeff=0.3,
dcoeff=0.0,
)
self.term = diffrax.ODETerm(self.xdot)
self.petab_problem = None
self.parameter_mappings = None
self.measurements = None
self.parameters = jnp.array([])

def set_petab_problem(self, petab_problem: petab.Problem) -> "JAXModel":
"""
Set the PEtab problem for the model and updates parameters to the nominal values.
:param petab_problem:
Petab problem to set.
:return: JAXModel instance
"""

is_leaf = lambda x: x is None if self.petab_problem is None else None # noqa: E731
model = eqx.tree_at(
lambda x: x.petab_problem,
self,
petab_problem,
is_leaf=is_leaf,
)

simulation_conditions = (
petab_problem.get_simulation_conditions_from_measurement_df()
)

def _set_parameter_mappings(
self, simulation_conditions: pd.DataFrame
) -> "JAXModel":
mappings = create_parameter_mapping(

Check warning on line 54 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L54

Added line #L54 was not covered by tests
petab_problem=petab_problem,
petab_problem=self.petab_problem,
simulation_conditions=simulation_conditions,
scaled_parameters=False,
amici_model=self,
Expand All @@ -95,31 +64,81 @@ def set_petab_problem(self, petab_problem: petab.Problem) -> "JAXModel":
simulation_conditions.iterrows(), mappings
)
}

is_leaf = ( # noqa: E731

Check warning on line 68 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L68

Added line #L68 was not covered by tests
lambda x: x is None if self.parameter_mappings is None else None
)
model = eqx.tree_at(
return eqx.tree_at(

Check warning on line 71 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L71

Added line #L71 was not covered by tests
lambda x: x.parameter_mappings,
model,
self,
parameter_mappings,
is_leaf=is_leaf,
)

def _set_measurements(
self, simulation_conditions: pd.DataFrame
) -> "JAXModel":
measurements = dict()
for _, simulation_condition in simulation_conditions.iterrows():
measurements_df = self.petab_problem.measurement_df
for k, v in simulation_condition.items():
measurements_df = measurements_df.query(f"{k} == '{v}'")

Check warning on line 85 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L81-L85

Added lines #L81 - L85 were not covered by tests

ts = _get_timepoints_with_replicates(measurements_df)
my = _get_measurements_and_sigmas(

Check warning on line 88 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L87-L88

Added lines #L87 - L88 were not covered by tests
measurements_df, ts, self.observable_ids
)[0].flatten()
measurements[tuple(simulation_condition)] = np.array(ts), my
is_leaf = ( # noqa: E731

Check warning on line 92 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L91-L92

Added lines #L91 - L92 were not covered by tests
lambda x: x is None if self.measurements is None else None
)
return eqx.tree_at(

Check warning on line 95 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L95

Added line #L95 was not covered by tests
lambda x: x.measurements,
self,
measurements,
is_leaf=is_leaf,
)

def _set_nominal_parameter_values(self) -> "JAXModel":
nominal_values = jnp.array(

Check warning on line 103 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L103

Added line #L103 was not covered by tests
[
petab.scale(
model.petab_problem.parameter_df.loc[
self.petab_problem.parameter_df.loc[
pval, petab.NOMINAL_VALUE
],
model.petab_problem.parameter_df.loc[
self.petab_problem.parameter_df.loc[
pval, petab.PARAMETER_SCALE
],
)
for pval in model.petab_parameter_ids()
for pval in self.petab_parameter_ids()
]
)
return eqx.tree_at(lambda x: x.parameters, self, nominal_values)

Check warning on line 116 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L116

Added line #L116 was not covered by tests

return eqx.tree_at(lambda x: x.parameters, model, nominal_values)
def _set_petab_problem(self, petab_problem: petab.Problem) -> "JAXModel":
is_leaf = lambda x: x is None if self.petab_problem is None else None # noqa: E731
return eqx.tree_at(

Check warning on line 120 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L119-L120

Added lines #L119 - L120 were not covered by tests
lambda x: x.petab_problem,
self,
petab_problem,
is_leaf=is_leaf,
)

def set_petab_problem(self, petab_problem: petab.Problem) -> "JAXModel":
"""
Set the PEtab problem for the model and updates parameters to the nominal values.
:param petab_problem:
Petab problem to set.
:return: JAXModel instance
"""

model = self._set_petab_problem(petab_problem)
simulation_conditions = (

Check warning on line 136 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L135-L136

Added lines #L135 - L136 were not covered by tests
petab_problem.get_simulation_conditions_from_measurement_df()
)
model = model._set_parameter_mappings(simulation_conditions)
model = model._set_measurements(simulation_conditions)
return model._set_nominal_parameter_values()

Check warning on line 141 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L139-L141

Added lines #L139 - L141 were not covered by tests

@staticmethod
@abstractmethod
Expand Down Expand Up @@ -216,7 +235,7 @@ def _posteq(self, p, k, x, tcl):

def _eq(self, p, k, tcl, x0):
sol = diffrax.diffeqsolve(

Check warning on line 237 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L237

Added line #L237 was not covered by tests
self.term,
diffrax.ODETerm(self.xdot),
self.solver,
args=(p, k, tcl),
t0=0.0,
Expand All @@ -232,7 +251,7 @@ def _eq(self, p, k, tcl, x0):
def _solve(self, ts, p, k, x0, checkpointed):
tcl = self.tcl(x0, p, k)
sol = diffrax.diffeqsolve(
self.term,
diffrax.ODETerm(self.xdot),
self.solver,
args=(p, k, tcl),
t0=0.0,
Expand Down Expand Up @@ -264,15 +283,15 @@ def _loss(self, obs: jnp.ndarray, sigmay: jnp.ndarray, my: np.ndarray):
loss_fun = jax.vmap(self.Jy, in_axes=(0, 0, 0))
return -jnp.sum(loss_fun(obs, my, sigmay))

def _run(
def run_condition(
self,
ts: np.ndarray,
ts_dyn: np.ndarray,
ts: jnp.ndarray,
ts_dyn: jnp.ndarray,
p: jnp.ndarray,
k: np.ndarray,
k_preeq: np.ndarray,
k: jnp.ndarray,
k_preeq: jnp.ndarray,
my: jnp.ndarray,
pscale: np.ndarray,
pscale: jnp.ndarray,
checkpointed=True,
dynamic="true",
):
Expand Down Expand Up @@ -323,55 +342,55 @@ def _run(
return llh, (x_rdata, obs, stats)

@eqx.filter_jit
def run(
def _fun(
self,
ts: np.ndarray,
ts_dyn: np.ndarray,
ts: jnp.ndarray,
ts_dyn: jnp.ndarray,
p: jnp.ndarray,
k: np.ndarray,
k_preeq: np.ndarray,
my: np.ndarray,
pscale: np.ndarray,
k: jnp.ndarray,
k_preeq: jnp.ndarray,
my: jnp.ndarray,
pscale: jnp.ndarray,
dynamic="true",
):
return self._run(
return self.run_condition(
ts, ts_dyn, p, k, k_preeq, my, pscale, dynamic=dynamic
)

@eqx.filter_jit
def srun(
def _grad(
self,
ts: np.ndarray,
ts_dyn: np.ndarray,
ts: jnp.ndarray,
ts_dyn: jnp.ndarray,
p: jnp.ndarray,
k: np.ndarray,
k_preeq: np.ndarray,
my: np.ndarray,
pscale: np.ndarray,
k: jnp.ndarray,
k_preeq: jnp.ndarray,
my: jnp.ndarray,
pscale: jnp.ndarray,
dynamic="true",
):
(llh, (x, obs, stats)), sllh = (
jax.value_and_grad(self._run, 2, True)
jax.value_and_grad(self.run_condition, 2, True)
)(ts, ts_dyn, p, k, k_preeq, my, pscale, dynamic=dynamic)
return llh, sllh, (x, obs, stats)

@eqx.filter_jit
def s2run(
def _hessian(
self,
ts: np.ndarray,
ts_dyn: np.ndarray,
ts: jnp.ndarray,
ts_dyn: jnp.ndarray,
p: jnp.ndarray,
k: np.ndarray,
k_preeq: np.ndarray,
my: np.ndarray,
pscale: np.ndarray,
k: jnp.ndarray,
k_preeq: jnp.ndarray,
my: jnp.ndarray,
pscale: jnp.ndarray,
dynamic="true",
):
(llh, (x, obs, stats)), sllh = (
jax.value_and_grad(self._run, 2, True)
jax.value_and_grad(self.run_condition, 2, True)
)(ts, ts_dyn, p, k, k_preeq, my, pscale, dynamic=dynamic)

s2llh = jax.hessian(self._run, 2, True)(
s2llh = jax.hessian(self.run_condition, 2, True)(
ts,
ts_dyn,
p,
Expand All @@ -391,16 +410,7 @@ def run_simulation(
sensitivity_order: amici.SensitivityOrder = amici.SensitivityOrder.none,
):
parameter_mapping = self.parameter_mappings[simulation_condition]
measurements_df = self.petab_problem.measurement_df
for v, k in zip(
simulation_condition,
(
petab.SIMULATION_CONDITION_ID,
petab.PREEQUILIBRATION_CONDITION_ID,
),
):
measurements_df = measurements_df.query(f"{k} == '{v}'")
ts = _get_timepoints_with_replicates(measurements_df)
ts, my = self.measurements[simulation_condition]
p = jnp.array(

Check warning on line 414 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L412-L414

Added lines #L412 - L414 were not covered by tests
[
pval
Expand Down Expand Up @@ -430,10 +440,7 @@ def run_simulation(
if k in parameter_mapping.map_preeq_fix
]
)
my = _get_measurements_and_sigmas(
measurements_df, ts, self.observable_ids
)[0].flatten()
ts = np.array(ts)

ts_dyn = ts[np.isfinite(ts)]
dynamic = "true" if len(ts_dyn) and np.max(ts_dyn) > 0 else "false"

Check warning on line 445 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L444-L445

Added lines #L444 - L445 were not covered by tests

Expand All @@ -445,15 +452,15 @@ def run_simulation(
(

Check warning on line 452 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L451-L452

Added lines #L451 - L452 were not covered by tests
rdata_kwargs["llh"],
(rdata_kwargs["x"], rdata_kwargs["y"], rdata_kwargs["stats"]),
) = self.run(
) = self._fun(
ts, ts_dyn, p, k_sim, k_preeq, my, pscale, dynamic=dynamic
)
elif sensitivity_order == amici.SensitivityOrder.first:
(

Check warning on line 459 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L458-L459

Added lines #L458 - L459 were not covered by tests
rdata_kwargs["llh"],
rdata_kwargs["sllh"],
(rdata_kwargs["x"], rdata_kwargs["y"], rdata_kwargs["stats"]),
) = self.srun(
) = self._grad(
ts, ts_dyn, p, k_sim, k_preeq, my, pscale, dynamic=dynamic
)
elif sensitivity_order == amici.SensitivityOrder.second:
Expand All @@ -462,7 +469,7 @@ def run_simulation(
rdata_kwargs["sllh"],
rdata_kwargs["s2llh"],
(rdata_kwargs["x"], rdata_kwargs["y"], rdata_kwargs["stats"]),
) = self.s2run(
) = self._hessian(
ts, ts_dyn, p, k_sim, k_preeq, my, pscale, dynamic=dynamic
)

Expand Down
6 changes: 3 additions & 3 deletions python/tests/test_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,20 +151,20 @@ def check_fields_jax(
(
r_jax["llh"],
(r_jax["x"], r_jax["y"], r_jax["stats"]),
) = jax_model.run(**kwargs)
) = jax_model._fun(**kwargs)
elif sensi_order == amici.SensitivityOrder.first:
(
r_jax["llh"],
r_jax["sllh"],
(r_jax["x"], r_jax["y"], r_jax["stats"]),
) = jax_model.srun(**kwargs)
) = jax_model._grad(**kwargs)
elif sensi_order == amici.SensitivityOrder.second:
(
r_jax["llh"],
r_jax["sllh"],
r_jax["s2llh"],
(r_jax["x"], r_jax["y"], r_jax["stats"]),
) = jax_model.s2run(**kwargs)
) = jax_model._hessian(**kwargs)

for field in fields:
for r_amici, r_jax in zip(rs_amici, [r_jax]):
Expand Down

0 comments on commit a46e65d

Please sign in to comment.