Skip to content

Commit

Permalink
add parameter values to model class
Browse files Browse the repository at this point in the history
  • Loading branch information
FFroehlich committed Nov 12, 2024
1 parent a64f89b commit 7292451
Showing 1 changed file with 38 additions and 15 deletions.
53 changes: 38 additions & 15 deletions python/sdist/amici/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class JAXModel(eqx.Module):
icoeff: float
dcoeff: float
maxsteps: int
parameters: jnp.ndarray
term: diffrax.ODETerm
petab_problem: petab.Problem | None

Expand All @@ -58,17 +59,40 @@ def __init__(self):
)
self.term = diffrax.ODETerm(self.xdot)
self.petab_problem = None
self.parameters = jnp.array([])

Check warning on line 62 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L60-L62

Added lines #L60 - L62 were not covered by tests

def set_petab_problem(self, petab_problem: petab.Problem) -> "JAXModel":
"""
Set the PEtab problem for the model and updates parameters to the nominal values.
:param petab_problem:
Petab problem to set.
:return: JAXModel instance
"""
if self.petab_problem is None:
return eqx.tree_at(
model = eqx.tree_at(

Check warning on line 72 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L71-L72

Added lines #L71 - L72 were not covered by tests
lambda x: x.petab_problem,
self,
petab_problem,
is_leaf=lambda x: x is None,
)
else:
return eqx.tree_at(lambda x: x.petab_problem, self, petab_problem)
model = eqx.tree_at(lambda x: x.petab_problem, self, petab_problem)

Check warning on line 79 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L79

Added line #L79 was not covered by tests

nominal_values = jnp.array(

Check warning on line 81 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L81

Added line #L81 was not covered by tests
[
petab.scale(
model.petab_problem.parameter_df.loc[
pval, petab.NOMINAL_VALUE
],
model.petab_problem.parameter_df.loc[
pval, petab.PARAMETER_SCALE
],
)
for pval in model.petab_parameter_ids()
]
)

return eqx.tree_at(lambda x: x.parameters, model, nominal_values)

Check warning on line 95 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L95

Added line #L95 was not covered by tests

@staticmethod
@abstractmethod
Expand Down Expand Up @@ -138,7 +162,15 @@ def getFixedParameterIds(self) -> list[str]: # noqa: N802
"""
return self.fixed_parameter_ids

Check warning on line 163 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L163

Added line #L163 was not covered by tests

def unscale_p(self, p, pscale):
def petab_parameter_ids(self) -> list[str]:
return self.petab_problem.parameter_df[

Check warning on line 166 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L166

Added line #L166 was not covered by tests
self.petab_problem.parameter_df[petab.ESTIMATE] == 1
].index.tolist()

def get_petab_parameter_by_name(self, name: str) -> jnp.float_:
return self.parameters[self.petab_parameter_ids().index(name)]

Check warning on line 171 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L171

Added line #L171 was not covered by tests

def _unscale_p(self, p, pscale):
return jax.vmap(

Check warning on line 174 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L174

Added line #L174 was not covered by tests
lambda p_i, pscale_i: jnp.stack(
(p_i, jnp.exp(p_i), jnp.power(10, p_i))
Expand Down Expand Up @@ -217,7 +249,7 @@ def _run(
checkpointed=True,
dynamic="true",
):
ps = self.unscale_p(p, pscale)
ps = self._unscale_p(p, pscale)

Check warning on line 252 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L252

Added line #L252 was not covered by tests

# Pre-equilibration
if k_preeq.shape[0] > 0:
Expand Down Expand Up @@ -340,14 +372,7 @@ def run_simulation(
if isinstance(
pval := parameter_mapping.map_sim_var[par], Number
)
else petab.scale(
self.petab_problem.parameter_df.loc[
pval, petab.NOMINAL_VALUE
],
self.petab_problem.parameter_df.loc[
pval, petab.PARAMETER_SCALE
],
)
else self.get_petab_parameter_by_name(pval)
for par in self.parameter_ids
]
)
Expand Down Expand Up @@ -471,13 +496,11 @@ def run_simulations(
@dataclass
class ReturnDataJAX(dict):
x: np.array = None
sx: np.array = None
y: np.array = None
sy: np.array = None
sigmay: np.array = None
ssigmay: np.array = None
llh: np.array = None
sllh: np.array = None
s2llh: np.array = None
stats: dict = None

def __init__(self, *args, **kwargs):
Expand Down

0 comments on commit 7292451

Please sign in to comment.