diff --git a/python/sdist/amici/de_export.py b/python/sdist/amici/de_export.py
index 6f747d8d82..d773b0864e 100644
--- a/python/sdist/amici/de_export.py
+++ b/python/sdist/amici/de_export.py
@@ -344,6 +344,14 @@ def jnp_stack_str(array) -> str:
                 else "_"
                 for sym_name in sym_names
             },
+            **{
+                f"{sym_name.upper()}_IDS": "".join(
+                    f'"{strip_pysb(s)}", ' for s in self.model.sym(sym_name)
+                )
+                if self.model.sym(sym_name)
+                else "tuple()"
+                for sym_name in ("p", "k", "y", "x")
+            },
             **{
                 "MODEL_NAME": self.model_name,
             },
diff --git a/python/sdist/amici/jax.py b/python/sdist/amici/jax.py
index 74e601dd8c..5d70a08aef 100644
--- a/python/sdist/amici/jax.py
+++ b/python/sdist/amici/jax.py
@@ -1,15 +1,25 @@
 from abc import abstractmethod
 from dataclasses import dataclass
 from concurrent.futures import ThreadPoolExecutor
+from numbers import Number
 
 import diffrax
 import equinox as eqx
 import jax.numpy as jnp
 import numpy as np
+import pandas as pd
 import jax
-from collections.abc import Iterable
+import petab.v1 as petab
 
 import amici
+from amici.petab.parameter_mapping import (
+    ParameterMapping,
+    ParameterMappingForCondition,
+)
+from amici.petab.conditions import (
+    _get_timepoints_with_replicates,
+    _get_measurements_and_sigmas,
+)
 
 jax.config.update("jax_enable_x64", True)
 
@@ -83,6 +93,22 @@ def sigmay(y, p, k): ...
     @abstractmethod
     def Jy(y, my, sigmay): ...
 
+    @property
+    @abstractmethod
+    def state_ids(self): ...
+
+    @property
+    @abstractmethod
+    def observable_ids(self): ...
+
+    @property
+    @abstractmethod
+    def parameter_ids(self): ...
+
+    @property
+    @abstractmethod
+    def fixed_parameter_ids(self): ...
+
     def unscale_p(self, p, pscale):
         return jax.vmap(
             lambda p_i, pscale_i: jnp.stack(
@@ -154,9 +180,9 @@ def _run(
         self,
         ts: np.ndarray,
         ts_dyn: np.ndarray,
-        p: np.ndarray,
-        k: jnp.ndarray,
-        k_preeq: jnp.ndarray,
+        p: jnp.ndarray,
+        k: np.ndarray,
+        k_preeq: np.ndarray,
         my: jnp.ndarray,
         pscale: np.ndarray,
         checkpointed=True,
@@ -272,14 +298,50 @@ def s2run(
         return llh, sllh, s2llh, (x, obs, stats)
 
     def run_simulation(
-        self, edata: amici.ExpData, sensitivity_order: amici.SensitivityOrder
+        self,
+        parameter_mapping: ParameterMappingForCondition = None,
+        measurements: pd.DataFrame = None,
+        parameters: pd.DataFrame = None,
+        sensitivity_order: amici.SensitivityOrder = amici.SensitivityOrder.none,
     ):
-        ts = np.asarray(edata.getTimepoints())
-        p = jnp.asarray(edata.parameters)
-        k = np.asarray(edata.fixedParameters)
-        k_preeq = np.asarray(edata.fixedParametersPreequilibration)
-        my = np.asarray(edata.getObservedData())
-        pscale = np.asarray(edata.pscale)
+        cond_id, measurements_df = measurements
+        ts = _get_timepoints_with_replicates(measurements_df)
+        p = jnp.array(
+            [
+                pval
+                if isinstance(
+                    pval := parameter_mapping.map_sim_var[par], Number
+                )
+                else petab.scale(
+                    parameters.loc[pval, petab.NOMINAL_VALUE],
+                    parameters.loc[pval, petab.PARAMETER_SCALE],
+                )
+                for par in self.parameter_ids
+            ]
+        )
+        pscale = jnp.array(
+            [
+                0 if s == petab.LIN else 1 if s == petab.LOG else 2
+                for s in parameter_mapping.scale_map_sim_var.values()
+            ]
+        )
+        k_sim = np.array(
+            [
+                parameter_mapping.map_sim_fix[k]
+                for k in self.fixed_parameter_ids
+            ]
+        )
+        k_preeq = np.array(
+            [
+                parameter_mapping.map_preeq_fix[k]
+                for k in self.fixed_parameter_ids
+                if k in parameter_mapping.map_preeq_fix
+            ]
+        )
+        my = _get_measurements_and_sigmas(
+            measurements_df, ts, self.observable_ids
+        )[0].flatten()
+        ts = np.array(ts)
         ts_dyn = ts[np.isfinite(ts)]
         dynamic = "true" if len(ts_dyn) and np.max(ts_dyn) > 0 else "false"
 
@@ -290,7 +352,7 @@ def run_simulation(
                 rdata_kwargs["llh"],
                 (rdata_kwargs["x"], rdata_kwargs["y"], rdata_kwargs["stats"]),
             ) = self.run(
-                ts, ts_dyn, p, k, k_preeq, my, pscale, dynamic=dynamic
+                ts, ts_dyn, p, k_sim, k_preeq, my, pscale, dynamic=dynamic
             )
         elif sensitivity_order == amici.SensitivityOrder.first:
             (
@@ -298,7 +360,7 @@ def run_simulation(
                 rdata_kwargs["sllh"],
                 (rdata_kwargs["x"], rdata_kwargs["y"], rdata_kwargs["stats"]),
             ) = self.srun(
-                ts, ts_dyn, p, k, k_preeq, my, pscale, dynamic=dynamic
+                ts, ts_dyn, p, k_sim, k_preeq, my, pscale, dynamic=dynamic
             )
         elif sensitivity_order == amici.SensitivityOrder.second:
             (
@@ -307,7 +369,7 @@ def run_simulation(
                 rdata_kwargs["s2llh"],
                 (rdata_kwargs["x"], rdata_kwargs["y"], rdata_kwargs["stats"]),
             ) = self.s2run(
-                ts, ts_dyn, p, k, k_preeq, my, pscale, dynamic=dynamic
+                ts, ts_dyn, p, k_sim, k_preeq, my, pscale, dynamic=dynamic
             )
 
         for field in rdata_kwargs.keys():
@@ -324,18 +386,47 @@ def run_simulation(
 
     def run_simulations(
         self,
-        edatas: Iterable[amici.ExpData],
         sensitivity_order: amici.SensitivityOrder = amici.SensitivityOrder.none,
         num_threads: int = 1,
+        parameter_mappings: ParameterMapping = None,
+        parameters: pd.DataFrame = None,
+        simulation_conditions: pd.DataFrame = None,
+        measurements: pd.DataFrame = None,
     ):
         fun = eqx.Partial(
-            self.run_simulation, sensitivity_order=sensitivity_order
+            self.run_simulation,
+            sensitivity_order=sensitivity_order,
+            parameters=parameters,
         )
+        gb = (
+            [
+                petab.PREEQUILIBRATION_CONDITION_ID,
+                petab.SIMULATION_CONDITION_ID,
+            ]
+            if petab.PREEQUILIBRATION_CONDITION_ID in measurements.columns
+            and petab.PREEQUILIBRATION_CONDITION_ID in simulation_conditions
+            else petab.SIMULATION_CONDITION_ID
+        )
+
+        per_condition_measurements = measurements.groupby(gb)
+
+        order_conditions = [
+            tuple(c) if isinstance(c, np.ndarray) else c
+            for c in simulation_conditions[gb].values
+        ]
+
+        sorted_mappings = [
+            parameter_mappings[order_conditions.index(condition)]
+            for condition in per_condition_measurements.groups.keys()
+        ]
+
         if num_threads > 1:
             with ThreadPoolExecutor(max_workers=num_threads) as pool:
-                results = pool.map(fun, edatas)
+                results = pool.map(
+                    fun, sorted_mappings, per_condition_measurements
+                )
         else:
-            results = map(fun, edatas)
+            results = map(fun, sorted_mappings, per_condition_measurements)
         return list(results)
diff --git a/tests/benchmark-models/test_petab_benchmark.py b/tests/benchmark-models/test_petab_benchmark.py
index 58586e3329..54d92dcf88 100644
--- a/tests/benchmark-models/test_petab_benchmark.py
+++ b/tests/benchmark-models/test_petab_benchmark.py
@@ -31,8 +31,6 @@
     RDATAS,
     rdatas_to_measurement_df,
     simulate_petab,
-    create_edatas,
-    fill_in_parameters,
     create_parameter_mapping,
 )
 from petab.v1.visualize import plot_problem
@@ -292,31 +290,19 @@ def test_jax_llh(benchmark_problem):
     simulation_conditions = (
         petab_problem.get_simulation_conditions_from_measurement_df()
     )
-    edatas = create_edatas(
-        amici_model=amici_model,
-        petab_problem=petab_problem,
-        simulation_conditions=simulation_conditions,
-    )
-    problem_parameters = {
-        t.Index: getattr(t, petab.NOMINAL_VALUE)
-        for t in petab_problem.parameter_df.itertuples()
-    }
-    parameter_mapping = create_parameter_mapping(
+    mappings = create_parameter_mapping(
         petab_problem=petab_problem,
         simulation_conditions=simulation_conditions,
         scaled_parameters=False,
         amici_model=amici_model,
     )
-    fill_in_parameters(
-        edatas=edatas,
-        problem_parameters=problem_parameters,
-        scaled_parameters=False,
-        parameter_mapping=parameter_mapping,
-        amici_model=amici_model,
+    rdatas_jax = jax_model.run_simulations(
+        parameter_mappings=mappings,
+        parameters=petab_problem.parameter_df,
+        simulation_conditions=simulation_conditions,
+        measurements=petab_problem.measurement_df,
     )
-    rdatas_jax = jax_model.run_simulations(edatas)
-
     llh_jax = sum(r.llh for r in rdatas_jax)
 
     assert np.isclose(