diff --git a/python/sdist/amici/jax.py b/python/sdist/amici/jax.py
index 5ad11680c9..fc16a533e1 100644
--- a/python/sdist/amici/jax.py
+++ b/python/sdist/amici/jax.py
@@ -7,7 +7,6 @@
 import equinox as eqx
 import jax.numpy as jnp
 import numpy as np
-import pandas as pd
 import jax
 import petab.v1 as petab
 
@@ -39,6 +38,7 @@ class JAXModel(eqx.Module):
     dcoeff: float
     maxsteps: int
     parameters: jnp.ndarray
+    parameter_mappings: dict[tuple[str], ParameterMappingForCondition]
     term: diffrax.ODETerm
     petab_problem: petab.Problem | None
 
@@ -59,6 +59,7 @@ def __init__(self):
         )
         self.term = diffrax.ODETerm(self.xdot)
         self.petab_problem = None
+        self.parameter_mappings = None
         self.parameters = jnp.array([])
 
     def set_petab_problem(self, petab_problem: petab.Problem) -> "JAXModel":
@@ -68,15 +69,41 @@ def set_petab_problem(self, petab_problem: petab.Problem) -> "JAXModel":
             Petab problem to set.
         :return: JAXModel instance
         """
-        if self.petab_problem is None:
-            model = eqx.tree_at(
-                lambda x: x.petab_problem,
-                self,
-                petab_problem,
-                is_leaf=lambda x: x is None,
+
+        is_leaf = lambda x: x is None if self.petab_problem is None else None  # noqa: E731
+        model = eqx.tree_at(
+            lambda x: x.petab_problem,
+            self,
+            petab_problem,
+            is_leaf=is_leaf,
+        )
+
+        simulation_conditions = (
+            petab_problem.get_simulation_conditions_from_measurement_df()
+        )
+
+        mappings = create_parameter_mapping(
+            petab_problem=petab_problem,
+            simulation_conditions=simulation_conditions,
+            scaled_parameters=False,
+            amici_model=self,
+        )
+
+        parameter_mappings = {
+            tuple(simulation_condition.values): mapping
+            for (_, simulation_condition), mapping in zip(
+                simulation_conditions.iterrows(), mappings
             )
-        else:
-            model = eqx.tree_at(lambda x: x.petab_problem, self, petab_problem)
+        }
+        is_leaf = (  # noqa: E731
+            lambda x: x is None if self.parameter_mappings is None else None
+        )
+        model = eqx.tree_at(
+            lambda x: x.parameter_mappings,
+            model,
+            parameter_mappings,
+            is_leaf=is_leaf,
+        )
 
         nominal_values = jnp.array(
             [
@@ -360,11 +387,19 @@ def s2run(
 
     def run_simulation(
         self,
-        parameter_mapping: ParameterMappingForCondition = None,
-        measurements: pd.DataFrame = None,
+        simulation_condition: tuple[str],
         sensitivity_order: amici.SensitivityOrder = amici.SensitivityOrder.none,
     ):
-        cond_id, measurements_df = measurements
+        parameter_mapping = self.parameter_mappings[simulation_condition]
+        measurements_df = self.petab_problem.measurement_df
+        for v, k in zip(
+            simulation_condition,
+            (
+                petab.SIMULATION_CONDITION_ID,
+                petab.PREEQUILIBRATION_CONDITION_ID,
+            ),
+        ):
+            measurements_df = measurements_df.query(f"{k} == '{v}'")
         ts = _get_timepoints_with_replicates(measurements_df)
         p = jnp.array(
             [
@@ -402,7 +437,9 @@ def run_simulation(
         ts_dyn = ts[np.isfinite(ts)]
         dynamic = "true" if len(ts_dyn) and np.max(ts_dyn) > 0 else "false"
 
-        rdata_kwargs = dict()
+        rdata_kwargs = dict(
+            simulation_condition=simulation_condition,
+        )
 
         if sensitivity_order == amici.SensitivityOrder.none:
             (
@@ -445,56 +482,24 @@ def run_simulations(
         self,
         sensitivity_order: amici.SensitivityOrder = amici.SensitivityOrder.none,
         num_threads: int = 1,
-        simulation_conditions: pd.DataFrame = None,
+        simulation_conditions: tuple[tuple[str]] = None,
     ):
         fun = eqx.Partial(
             self.run_simulation,
             sensitivity_order=sensitivity_order,
         )
 
-        gb = (
-            [
-                petab.PREEQUILIBRATION_CONDITION_ID,
-                petab.SIMULATION_CONDITION_ID,
-            ]
-            if petab.PREEQUILIBRATION_CONDITION_ID
-            in self.petab_problem.measurement_df
-            and petab.PREEQUILIBRATION_CONDITION_ID in simulation_conditions
-            else petab.SIMULATION_CONDITION_ID
-        )
-
-        per_condition_measurements = self.petab_problem.measurement_df.groupby(
-            gb
-        )
-
-        order_conditions = [
-            tuple(c) if isinstance(c, np.ndarray) else c
-            for c in simulation_conditions[gb].values
-        ]
-
-        parameter_mappings = create_parameter_mapping(
-            petab_problem=self.petab_problem,
-            simulation_conditions=simulation_conditions,
-            scaled_parameters=False,
-            amici_model=self,
-        )
-
-        sorted_mappings = [
-            parameter_mappings[order_conditions.index(condition)]
-            for condition in per_condition_measurements.groups.keys()
-        ]
         if num_threads > 1:
             with ThreadPoolExecutor(max_workers=num_threads) as pool:
-                results = pool.map(
-                    fun, sorted_mappings, per_condition_measurements
-                )
+                results = pool.map(fun, simulation_conditions)
         else:
-            results = map(fun, sorted_mappings, per_condition_measurements)
+            results = map(fun, simulation_conditions)
         return list(results)
 
 
 @dataclass
 class ReturnDataJAX(dict):
+    simulation_condition: tuple[str] = None
     x: np.array = None
     y: np.array = None
     sigmay: np.array = None
diff --git a/tests/benchmark-models/test_petab_benchmark.py b/tests/benchmark-models/test_petab_benchmark.py
index 6667a6aae3..97d96af324 100644
--- a/tests/benchmark-models/test_petab_benchmark.py
+++ b/tests/benchmark-models/test_petab_benchmark.py
@@ -290,6 +290,9 @@ def test_jax_llh(benchmark_problem):
     simulation_conditions = (
         petab_problem.get_simulation_conditions_from_measurement_df()
     )
+    simulation_conditions = tuple(
+        tuple(row) for _, row in simulation_conditions.iterrows()
+    )
     rdatas_jax = jax_model.run_simulations(
         simulation_conditions=simulation_conditions,
    )
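
For context, a minimal usage sketch of the reworked interface, mirroring the updated benchmark test above. The `load_model()` helper and the `"problem.yaml"` path are placeholders, not part of this diff; only the tuple conversion and the `set_petab_problem` / `run_simulations` calls follow the changes shown here.

```python
import petab.v1 as petab

# Placeholders: obtaining the generated JAXModel subclass and the PEtab
# problem is outside the scope of this diff.
petab_problem = petab.Problem.from_yaml("problem.yaml")
jax_model = load_model()  # hypothetical helper returning a JAXModel instance

# set_petab_problem() returns an updated model and now also precomputes
# parameter_mappings, keyed by simulation-condition tuples.
jax_model = jax_model.set_petab_problem(petab_problem)

# Simulation conditions are passed as plain tuples instead of a DataFrame.
simulation_conditions = (
    petab_problem.get_simulation_conditions_from_measurement_df()
)
simulation_conditions = tuple(
    tuple(row) for _, row in simulation_conditions.iterrows()
)

rdatas = jax_model.run_simulations(
    simulation_conditions=simulation_conditions,
)
```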