Skip to content

Commit

Permalink
refactor parameter mapping
Browse files Browse the repository at this point in the history
  • Loading branch information
FFroehlich committed Nov 12, 2024
1 parent 7292451 commit da02106
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 49 deletions.
103 changes: 54 additions & 49 deletions python/sdist/amici/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import equinox as eqx
import jax.numpy as jnp
import numpy as np
import pandas as pd
import jax
import petab.v1 as petab

Expand Down Expand Up @@ -39,6 +38,7 @@ class JAXModel(eqx.Module):
dcoeff: float
maxsteps: int
parameters: jnp.ndarray
parameter_mappings: dict[tuple[str], ParameterMappingForCondition]
term: diffrax.ODETerm
petab_problem: petab.Problem | None

Expand All @@ -59,6 +59,7 @@ def __init__(self):
)
self.term = diffrax.ODETerm(self.xdot)
self.petab_problem = None
self.parameter_mappings = None
self.parameters = jnp.array([])

Check warning on line 63 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L60-L63

Added lines #L60 - L63 were not covered by tests

def set_petab_problem(self, petab_problem: petab.Problem) -> "JAXModel":
Expand All @@ -68,15 +69,41 @@ def set_petab_problem(self, petab_problem: petab.Problem) -> "JAXModel":
Petab problem to set.
:return: JAXModel instance
"""
if self.petab_problem is None:
model = eqx.tree_at(
lambda x: x.petab_problem,
self,
petab_problem,
is_leaf=lambda x: x is None,

is_leaf = lambda x: x is None if self.petab_problem is None else None # noqa: E731
model = eqx.tree_at(

Check warning on line 74 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L73-L74

Added lines #L73 - L74 were not covered by tests
lambda x: x.petab_problem,
self,
petab_problem,
is_leaf=is_leaf,
)

simulation_conditions = (

Check warning on line 81 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L81

Added line #L81 was not covered by tests
petab_problem.get_simulation_conditions_from_measurement_df()
)

mappings = create_parameter_mapping(

Check warning on line 85 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L85

Added line #L85 was not covered by tests
petab_problem=petab_problem,
simulation_conditions=simulation_conditions,
scaled_parameters=False,
amici_model=self,
)

parameter_mappings = {

Check warning on line 92 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L92

Added line #L92 was not covered by tests
tuple(simulation_condition.values): mapping
for (_, simulation_condition), mapping in zip(
simulation_conditions.iterrows(), mappings
)
else:
model = eqx.tree_at(lambda x: x.petab_problem, self, petab_problem)
}
is_leaf = ( # noqa: E731

Check warning on line 98 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L98

Added line #L98 was not covered by tests
lambda x: x is None if self.parameter_mappings is None else None
)
model = eqx.tree_at(

Check warning on line 101 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L101

Added line #L101 was not covered by tests
lambda x: x.parameter_mappings,
model,
parameter_mappings,
is_leaf=is_leaf,
)

nominal_values = jnp.array(

Check warning on line 108 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L108

Added line #L108 was not covered by tests
[
Expand Down Expand Up @@ -360,11 +387,19 @@ def s2run(

def run_simulation(
self,
parameter_mapping: ParameterMappingForCondition = None,
measurements: pd.DataFrame = None,
simulation_condition: tuple[str],
sensitivity_order: amici.SensitivityOrder = amici.SensitivityOrder.none,
):
cond_id, measurements_df = measurements
parameter_mapping = self.parameter_mappings[simulation_condition]
measurements_df = self.petab_problem.measurement_df
for v, k in zip(

Check warning on line 395 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L393-L395

Added lines #L393 - L395 were not covered by tests
simulation_condition,
(
petab.SIMULATION_CONDITION_ID,
petab.PREEQUILIBRATION_CONDITION_ID,
),
):
measurements_df = measurements_df.query(f"{k} == '{v}'")
ts = _get_timepoints_with_replicates(measurements_df)
p = jnp.array(

Check warning on line 404 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L402-L404

Added lines #L402 - L404 were not covered by tests
[
Expand Down Expand Up @@ -402,7 +437,9 @@ def run_simulation(
ts_dyn = ts[np.isfinite(ts)]
dynamic = "true" if len(ts_dyn) and np.max(ts_dyn) > 0 else "false"

Check warning on line 438 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L436-L438

Added lines #L436 - L438 were not covered by tests

rdata_kwargs = dict()
rdata_kwargs = dict(

Check warning on line 440 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L440

Added line #L440 was not covered by tests
simulation_condition=simulation_condition,
)

if sensitivity_order == amici.SensitivityOrder.none:
(

Check warning on line 445 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L444-L445

Added lines #L444 - L445 were not covered by tests
Expand Down Expand Up @@ -445,56 +482,24 @@ def run_simulations(
self,
sensitivity_order: amici.SensitivityOrder = amici.SensitivityOrder.none,
num_threads: int = 1,
simulation_conditions: pd.DataFrame = None,
simulation_conditions: tuple[tuple[str]] = None,
):
fun = eqx.Partial(

Check warning on line 487 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L487

Added line #L487 was not covered by tests
self.run_simulation,
sensitivity_order=sensitivity_order,
)
gb = (
[
petab.PREEQUILIBRATION_CONDITION_ID,
petab.SIMULATION_CONDITION_ID,
]
if petab.PREEQUILIBRATION_CONDITION_ID
in self.petab_problem.measurement_df
and petab.PREEQUILIBRATION_CONDITION_ID in simulation_conditions
else petab.SIMULATION_CONDITION_ID
)

per_condition_measurements = self.petab_problem.measurement_df.groupby(
gb
)

order_conditions = [
tuple(c) if isinstance(c, np.ndarray) else c
for c in simulation_conditions[gb].values
]

parameter_mappings = create_parameter_mapping(
petab_problem=self.petab_problem,
simulation_conditions=simulation_conditions,
scaled_parameters=False,
amici_model=self,
)

sorted_mappings = [
parameter_mappings[order_conditions.index(condition)]
for condition in per_condition_measurements.groups.keys()
]

if num_threads > 1:
with ThreadPoolExecutor(max_workers=num_threads) as pool:
results = pool.map(
fun, sorted_mappings, per_condition_measurements
)
results = pool.map(fun, simulation_conditions)

Check warning on line 494 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L492-L494

Added lines #L492 - L494 were not covered by tests
else:
results = map(fun, sorted_mappings, per_condition_measurements)
results = map(fun, simulation_conditions)
return list(results)

Check warning on line 497 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L496-L497

Added lines #L496 - L497 were not covered by tests


@dataclass
class ReturnDataJAX(dict):
simulation_condition: tuple[str] = None
x: np.array = None
y: np.array = None
sigmay: np.array = None
Expand Down
3 changes: 3 additions & 0 deletions tests/benchmark-models/test_petab_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,9 @@ def test_jax_llh(benchmark_problem):
simulation_conditions = (
petab_problem.get_simulation_conditions_from_measurement_df()
)
simulation_conditions = tuple(
tuple(row) for _, row in simulation_conditions.iterrows()
)
rdatas_jax = jax_model.run_simulations(
simulation_conditions=simulation_conditions,
)
Expand Down

0 comments on commit da02106

Please sign in to comment.