Skip to content

Commit

Permalink
refactor: remove use of edatas
Browse files Browse the repository at this point in the history
  • Loading branch information
FFroehlich committed Nov 9, 2024
1 parent 51bd18c commit c7c5d4b
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 38 deletions.
8 changes: 8 additions & 0 deletions python/sdist/amici/de_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,14 @@ def jnp_stack_str(array) -> str:
else "_"
for sym_name in sym_names
},
**{
f"{sym_name.upper()}_IDS": "".join(
f'"{strip_pysb(s)}", ' for s in self.model.sym(sym_name)
)
if self.model.sym(sym_name)
else "tuple()"
for sym_name in ("p", "k", "y", "x")
},
**{
"MODEL_NAME": self.model_name,
},
Expand Down
127 changes: 109 additions & 18 deletions python/sdist/amici/jax.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,25 @@
from abc import abstractmethod
from dataclasses import dataclass
from concurrent.futures import ThreadPoolExecutor
from numbers import Number

import diffrax
import equinox as eqx
import jax.numpy as jnp
import numpy as np
import pandas as pd
import jax
from collections.abc import Iterable
import petab.v1 as petab

import amici
from amici.petab.parameter_mapping import (
ParameterMapping,
ParameterMappingForCondition,
)
from amici.petab.conditions import (
_get_timepoints_with_replicates,
_get_measurements_and_sigmas,
)

jax.config.update("jax_enable_x64", True)

Expand Down Expand Up @@ -83,6 +93,22 @@ def sigmay(y, p, k): ...
@abstractmethod
def Jy(y, my, sigmay): ...

@property
@abstractmethod
def state_ids(self): ...

@property
@abstractmethod
def observable_ids(self): ...

@property
@abstractmethod
def parameter_ids(self): ...

@property
@abstractmethod
def fixed_parameter_ids(self): ...

def unscale_p(self, p, pscale):
return jax.vmap(
lambda p_i, pscale_i: jnp.stack(
Expand Down Expand Up @@ -154,9 +180,9 @@ def _run(
self,
ts: np.ndarray,
ts_dyn: np.ndarray,
p: np.ndarray,
k: jnp.ndarray,
k_preeq: jnp.ndarray,
p: jnp.ndarray,
k: np.ndarray,
k_preeq: np.ndarray,
my: jnp.ndarray,
pscale: np.ndarray,
checkpointed=True,
Expand Down Expand Up @@ -272,14 +298,50 @@ def s2run(
return llh, sllh, s2llh, (x, obs, stats)

def run_simulation(
self, edata: amici.ExpData, sensitivity_order: amici.SensitivityOrder
self,
parameter_mapping: ParameterMappingForCondition = None,
measurements: pd.DataFrame = None,
parameters: pd.DataFrame = None,
sensitivity_order: amici.SensitivityOrder = amici.SensitivityOrder.none,
):
ts = np.asarray(edata.getTimepoints())
p = jnp.asarray(edata.parameters)
k = np.asarray(edata.fixedParameters)
k_preeq = np.asarray(edata.fixedParametersPreequilibration)
my = np.asarray(edata.getObservedData())
pscale = np.asarray(edata.pscale)
cond_id, measurements_df = measurements
ts = _get_timepoints_with_replicates(measurements_df)
p = jnp.array(
[
pval
if isinstance(
pval := parameter_mapping.map_sim_var[par], Number
)
else petab.scale(
parameters.loc[pval, petab.NOMINAL_VALUE],
parameters.loc[pval, petab.PARAMETER_SCALE],
)
for par in self.parameter_ids
]
)
pscale = jnp.array(
[
0 if s == petab.LIN else 1 if s == petab.LOG else 2
for s in parameter_mapping.scale_map_sim_var.values()
]
)
k_sim = np.array(
[
parameter_mapping.map_sim_fix[k]
for k in self.fixed_parameter_ids
]
)
k_preeq = np.array(
[
parameter_mapping.map_preeq_fix[k]
for k in self.fixed_parameter_ids
if k in parameter_mapping.map_preeq_fix
]
)
my = _get_measurements_and_sigmas(
measurements_df, ts, self.observable_ids
)[0].flatten()
ts = np.array(ts)
ts_dyn = ts[np.isfinite(ts)]
dynamic = "true" if len(ts_dyn) and np.max(ts_dyn) > 0 else "false"

Expand All @@ -290,15 +352,15 @@ def run_simulation(
rdata_kwargs["llh"],
(rdata_kwargs["x"], rdata_kwargs["y"], rdata_kwargs["stats"]),
) = self.run(
ts, ts_dyn, p, k, k_preeq, my, pscale, dynamic=dynamic
ts, ts_dyn, p, k_sim, k_preeq, my, pscale, dynamic=dynamic
)
elif sensitivity_order == amici.SensitivityOrder.first:
(
rdata_kwargs["llh"],
rdata_kwargs["sllh"],
(rdata_kwargs["x"], rdata_kwargs["y"], rdata_kwargs["stats"]),
) = self.srun(
ts, ts_dyn, p, k, k_preeq, my, pscale, dynamic=dynamic
ts, ts_dyn, p, k_sim, k_preeq, my, pscale, dynamic=dynamic
)
elif sensitivity_order == amici.SensitivityOrder.second:
(
Expand All @@ -307,7 +369,7 @@ def run_simulation(
rdata_kwargs["s2llh"],
(rdata_kwargs["x"], rdata_kwargs["y"], rdata_kwargs["stats"]),
) = self.s2run(
ts, ts_dyn, p, k, k_preeq, my, pscale, dynamic=dynamic
ts, ts_dyn, p, k_sim, k_preeq, my, pscale, dynamic=dynamic
)

for field in rdata_kwargs.keys():
Expand All @@ -324,18 +386,47 @@ def run_simulation(

def run_simulations(
self,
edatas: Iterable[amici.ExpData],
sensitivity_order: amici.SensitivityOrder = amici.SensitivityOrder.none,
num_threads: int = 1,
parameter_mappings: ParameterMapping = None,
parameters: pd.DataFrame = None,
simulation_conditions: pd.DataFrame = None,
measurements: pd.DataFrame = None,
):
fun = eqx.Partial(
self.run_simulation, sensitivity_order=sensitivity_order
self.run_simulation,
sensitivity_order=sensitivity_order,
parameters=parameters,
)
gb = (
[
petab.PREEQUILIBRATION_CONDITION_ID,
petab.SIMULATION_CONDITION_ID,
]
if petab.PREEQUILIBRATION_CONDITION_ID in measurements.columns
and petab.PREEQUILIBRATION_CONDITION_ID in simulation_conditions
else petab.SIMULATION_CONDITION_ID
)

per_condition_measurements = measurements.groupby(gb)

order_conditions = [
tuple(c) if isinstance(c, np.ndarray) else c
for c in simulation_conditions[gb].values
]

sorted_mappings = [
parameter_mappings[order_conditions.index(condition)]
for condition in per_condition_measurements.groups.keys()
]

if num_threads > 1:
with ThreadPoolExecutor(max_workers=num_threads) as pool:
results = pool.map(fun, edatas)
results = pool.map(
fun, sorted_mappings, per_condition_measurements
)
else:
results = map(fun, edatas)
results = map(fun, sorted_mappings, per_condition_measurements)
return list(results)


Expand Down
26 changes: 6 additions & 20 deletions tests/benchmark-models/test_petab_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,6 @@
RDATAS,
rdatas_to_measurement_df,
simulate_petab,
create_edatas,
fill_in_parameters,
create_parameter_mapping,
)
from petab.v1.visualize import plot_problem
Expand Down Expand Up @@ -292,31 +290,19 @@ def test_jax_llh(benchmark_problem):
simulation_conditions = (
petab_problem.get_simulation_conditions_from_measurement_df()
)
edatas = create_edatas(
amici_model=amici_model,
petab_problem=petab_problem,
simulation_conditions=simulation_conditions,
)
problem_parameters = {
t.Index: getattr(t, petab.NOMINAL_VALUE)
for t in petab_problem.parameter_df.itertuples()
}
parameter_mapping = create_parameter_mapping(
mappings = create_parameter_mapping(
petab_problem=petab_problem,
simulation_conditions=simulation_conditions,
scaled_parameters=False,
amici_model=amici_model,
)
fill_in_parameters(
edatas=edatas,
problem_parameters=problem_parameters,
scaled_parameters=False,
parameter_mapping=parameter_mapping,
amici_model=amici_model,
rdatas_jax = jax_model.run_simulations(
parameter_mappings=mappings,
parameters=petab_problem.parameter_df,
simulation_conditions=simulation_conditions,
measurements=petab_problem.measurement_df,
)

rdatas_jax = jax_model.run_simulations(edatas)

llh_jax = sum(r.llh for r in rdatas_jax)

assert np.isclose(
Expand Down

0 comments on commit c7c5d4b

Please sign in to comment.