From 72924518fd20e45c5ab986ee5f741997aaab9694 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabian=20Fr=C3=B6hlich?= Date: Tue, 12 Nov 2024 16:28:39 +0000 Subject: [PATCH] add parameter values to model class --- python/sdist/amici/jax.py | 53 ++++++++++++++++++++++++++++----------- 1 file changed, 38 insertions(+), 15 deletions(-) diff --git a/python/sdist/amici/jax.py b/python/sdist/amici/jax.py index 3597404cea..5ad11680c9 100644 --- a/python/sdist/amici/jax.py +++ b/python/sdist/amici/jax.py @@ -38,6 +38,7 @@ class JAXModel(eqx.Module): icoeff: float dcoeff: float maxsteps: int + parameters: jnp.ndarray term: diffrax.ODETerm petab_problem: petab.Problem | None @@ -58,17 +59,40 @@ def __init__(self): ) self.term = diffrax.ODETerm(self.xdot) self.petab_problem = None + self.parameters = jnp.array([]) def set_petab_problem(self, petab_problem: petab.Problem) -> "JAXModel": + """ + Set the PEtab problem for the model and updates parameters to the nominal values. + :param petab_problem: + Petab problem to set. + :return: JAXModel instance + """ if self.petab_problem is None: - return eqx.tree_at( + model = eqx.tree_at( lambda x: x.petab_problem, self, petab_problem, is_leaf=lambda x: x is None, ) else: - return eqx.tree_at(lambda x: x.petab_problem, self, petab_problem) + model = eqx.tree_at(lambda x: x.petab_problem, self, petab_problem) + + nominal_values = jnp.array( + [ + petab.scale( + model.petab_problem.parameter_df.loc[ + pval, petab.NOMINAL_VALUE + ], + model.petab_problem.parameter_df.loc[ + pval, petab.PARAMETER_SCALE + ], + ) + for pval in model.petab_parameter_ids() + ] + ) + + return eqx.tree_at(lambda x: x.parameters, model, nominal_values) @staticmethod @abstractmethod @@ -138,7 +162,15 @@ def getFixedParameterIds(self) -> list[str]: # noqa: N802 """ return self.fixed_parameter_ids - def unscale_p(self, p, pscale): + def petab_parameter_ids(self) -> list[str]: + return self.petab_problem.parameter_df[ + self.petab_problem.parameter_df[petab.ESTIMATE] == 1 + ].index.tolist() + + def get_petab_parameter_by_name(self, name: str) -> jnp.float_: + return self.parameters[self.petab_parameter_ids().index(name)] + + def _unscale_p(self, p, pscale): return jax.vmap( lambda p_i, pscale_i: jnp.stack( (p_i, jnp.exp(p_i), jnp.power(10, p_i)) @@ -217,7 +249,7 @@ def _run( checkpointed=True, dynamic="true", ): - ps = self.unscale_p(p, pscale) + ps = self._unscale_p(p, pscale) # Pre-equilibration if k_preeq.shape[0] > 0: @@ -340,14 +372,7 @@ def run_simulation( if isinstance( pval := parameter_mapping.map_sim_var[par], Number ) - else petab.scale( - self.petab_problem.parameter_df.loc[ - pval, petab.NOMINAL_VALUE - ], - self.petab_problem.parameter_df.loc[ - pval, petab.PARAMETER_SCALE - ], - ) + else self.get_petab_parameter_by_name(pval) for par in self.parameter_ids ] ) @@ -471,13 +496,11 @@ def run_simulations( @dataclass class ReturnDataJAX(dict): x: np.array = None - sx: np.array = None y: np.array = None - sy: np.array = None sigmay: np.array = None - ssigmay: np.array = None llh: np.array = None sllh: np.array = None + s2llh: np.array = None stats: dict = None def __init__(self, *args, **kwargs):