Skip to content

Commit

Permalink
implement jax-based reinitialisation
Browse files Browse the repository at this point in the history
  • Loading branch information
FFroehlich committed Dec 5, 2024
1 parent 9d82a6c commit ee84559
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 19 deletions.
29 changes: 15 additions & 14 deletions python/sdist/amici/jax/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,8 @@ def simulate_condition(
my: jt.Float[jt.Array, "nt"],
iys: jt.Int[jt.Array, "nt"],
x_preeq: jt.Float[jt.Array, "nx"],
mask_reinit: jt.Bool[jt.Array, "nx"],
x_reinit: jt.Float[jt.Array, "nx"],
solver: diffrax.AbstractSolver,
controller: diffrax.AbstractStepSizeController,
adjoint: diffrax.AbstractAdjoint,
Expand Down Expand Up @@ -482,47 +484,46 @@ def simulate_condition(
:return:
output according to `ret` and statistics
"""
# Pre-equilibration
if x_preeq.shape[0] > 0:
current_x = self._x_solver(x_preeq)
# update tcl with new parameters
tcl = self._tcl(x_preeq, p)
x = x_preeq

Check warning on line 488 in python/sdist/amici/jax/model.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/model.py#L487-L488

Added lines #L487 - L488 were not covered by tests
else:
x0 = self._x0(p)
current_x = self._x_solver(x0)
x = self._x0(p)

Check warning on line 490 in python/sdist/amici/jax/model.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/model.py#L490

Added line #L490 was not covered by tests

x = jnp.where(mask_reinit, x_reinit, x)
x_solver = self._x_solver(x)
tcl = self._tcl(x, p)

Check warning on line 494 in python/sdist/amici/jax/model.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/model.py#L492-L494

Added lines #L492 - L494 were not covered by tests

tcl = self._tcl(x0, p)
x_preq = jnp.repeat(current_x.reshape(1, -1), ts_init.shape[0], axis=0)
x_preq = jnp.repeat(x_solver.reshape(1, -1), ts_init.shape[0], axis=0)

Check warning on line 496 in python/sdist/amici/jax/model.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/model.py#L496

Added line #L496 was not covered by tests

# Dynamic simulation
if ts_dyn.shape[0] > 0:
x_dyn, stats_dyn = self._solve(
p,
ts_dyn,
tcl,
current_x,
x_solver,
solver,
controller,
max_steps,
adjoint,
)
current_x = x_dyn[-1, :]
x_solver = x_dyn[-1, :]

Check warning on line 510 in python/sdist/amici/jax/model.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/model.py#L510

Added line #L510 was not covered by tests
else:
x_dyn = jnp.repeat(
current_x.reshape(1, -1), ts_dyn.shape[0], axis=0
x_solver.reshape(1, -1), ts_dyn.shape[0], axis=0
)
stats_dyn = None

# Post-equilibration
if ts_posteq.shape[0] > 0:
current_x, stats_posteq = self._eq(
p, tcl, current_x, solver, controller, max_steps
x_solver, stats_posteq = self._eq(

Check warning on line 519 in python/sdist/amici/jax/model.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/model.py#L519

Added line #L519 was not covered by tests
p, tcl, x_solver, solver, controller, max_steps
)
else:
stats_posteq = None

x_posteq = jnp.repeat(
current_x.reshape(1, -1), ts_posteq.shape[0], axis=0
x_solver.reshape(1, -1), ts_posteq.shape[0], axis=0
)

ts = jnp.concatenate((ts_init, ts_dyn, ts_posteq), axis=0)

Check warning on line 529 in python/sdist/amici/jax/model.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/model.py#L529

Added line #L529 was not covered by tests
Expand Down
69 changes: 69 additions & 0 deletions python/sdist/amici/jax/petab.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,70 @@ def load_parameters(
)
return self._unscale(p, pscale)

def load_reinitialisation(
    self,
    simulation_condition: str,
    p: 'jt.Float[jt.Array, "np"]',
) -> 'tuple[jt.Bool[jt.Array, "nx"], jt.Float[jt.Array, "nx"]]':  # noqa: F821
    """
    Load reinitialisation values and mask for the state vector for a simulation condition.

    A state is reinitialised iff the PEtab condition table has a column for it
    and the entry for this condition is not NaN (NaN marks "keep the model
    default"). The entry may be a number, a model parameter id (resolved
    against ``p``), or a PEtab parameter id (resolved via
    :meth:`get_petab_parameter_by_id`).

    :param simulation_condition:
        Simulation condition to load reinitialisation for.
    :param p:
        Parameters for the simulation condition, in model parameter order.
    :return:
        Tuple of reinitialisation mask and values for states. Entries of the
        value vector for non-reinitialised states are 0 and are ignored by the
        caller (masked out).
    """

    def _condition_entry(xname: str):
        """Raw condition-table entry for state ``xname``, or None if absent."""
        if xname not in self._petab_problem.condition_df:
            return None
        return self._petab_problem.condition_df.loc[
            simulation_condition, xname
        ]

    def _is_reinit(xname: str) -> bool:
        """Whether state ``xname`` is reinitialised for this condition."""
        xval = _condition_entry(xname)
        if xval is None:
            return False
        # numeric NaN means "no reinitialisation" per the PEtab spec
        return not (isinstance(xval, Number) and np.isnan(xval))

    def _reinit_value(xname: str):
        """Reinitialisation value for state ``xname`` (0 placeholder if unset)."""
        xval = _condition_entry(xname)
        if xval is None or (isinstance(xval, Number) and np.isnan(xval)):
            return 0.0
        if isinstance(xval, Number):
            return xval
        if xval in self.model.parameter_ids:
            # entry names a model parameter: use its value, not the whole
            # parameter vector (bug fix: previously returned `p` itself)
            return p[self.model.parameter_ids.index(xval)]
        # otherwise the entry references a PEtab parameter id
        return self.get_petab_parameter_by_id(xval)

    # the mask is purely structural, so no gradients may flow through it
    mask = jax.lax.stop_gradient(
        jnp.array([_is_reinit(xname) for xname in self.model.state_ids])
    )
    reinit_x = jnp.array(
        [_reinit_value(xname) for xname in self.model.state_ids]
    )
    return mask, reinit_x

Check warning on line 357 in python/sdist/amici/jax/petab.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/petab.py#L357

Added line #L357 was not covered by tests

def update_parameters(self, p: jt.Float[jt.Array, "np"]) -> "JAXProblem":
"""
Update parameters for the model.
Expand Down Expand Up @@ -329,6 +393,9 @@ def run_simulation(
simulation_condition
]
p = self.load_parameters(simulation_condition[0])
mask_reinit, x_reinit = self.load_reinitialisation(

Check warning on line 396 in python/sdist/amici/jax/petab.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/petab.py#L396

Added line #L396 was not covered by tests
simulation_condition[0], p
)
return self.model.simulate_condition(
p=p,
ts_init=jax.lax.stop_gradient(jnp.array(ts_preeq)),
Expand All @@ -337,6 +404,8 @@ def run_simulation(
my=jax.lax.stop_gradient(jnp.array(my)),
iys=jax.lax.stop_gradient(jnp.array(iys)),
x_preeq=x_preeq,
mask_reinit=mask_reinit,
x_reinit=x_reinit,
solver=solver,
controller=controller,
max_steps=max_steps,
Expand Down
13 changes: 8 additions & 5 deletions python/sdist/amici/petab/sbml_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,11 +348,14 @@ def import_model_sbml(
_workaround_observable_parameters(
observables, sigmas, sbml_model, output_parameter_defaults
)
fixed_parameters = _workaround_initial_states(
petab_problem=petab_problem,
sbml_model=sbml_model,
**kwargs,
)
if not jax:
fixed_parameters = _workaround_initial_states(
petab_problem=petab_problem,
sbml_model=sbml_model,
**kwargs,
)
else:
fixed_parameters = []

fixed_parameters.extend(
_get_fixed_parameters_sbml(
Expand Down

0 comments on commit ee84559

Please sign in to comment.