Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

State reinitialisation in JAX #2619

Merged
merged 29 commits into from
Dec 8, 2024
Merged
Changes from 1 commit
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
implement jax-based reinitialisation
FFroehlich committed Dec 5, 2024
commit ee84559fafedcbb695529ab3dbf28e1f8bc71cce
29 changes: 15 additions & 14 deletions python/sdist/amici/jax/model.py
Original file line number Diff line number Diff line change
@@ -433,6 +433,8 @@
my: jt.Float[jt.Array, "nt"],
iys: jt.Int[jt.Array, "nt"],
x_preeq: jt.Float[jt.Array, "nx"],
mask_reinit: jt.Bool[jt.Array, "nx"],
x_reinit: jt.Float[jt.Array, "nx"],
solver: diffrax.AbstractSolver,
controller: diffrax.AbstractStepSizeController,
adjoint: diffrax.AbstractAdjoint,
@@ -482,50 +484,49 @@
:return:
output according to `ret` and statistics
"""
# Pre-equilibration
if x_preeq.shape[0] > 0:
current_x = self._x_solver(x_preeq)
# update tcl with new parameters
tcl = self._tcl(x_preeq, p)
x = x_preeq

Check warning on line 488 in python/sdist/amici/jax/model.py

Codecov / codecov/patch

python/sdist/amici/jax/model.py#L487-L488

Added lines #L487 - L488 were not covered by tests
else:
x0 = self._x0(p)
current_x = self._x_solver(x0)
x = self._x0(p)

Check warning on line 490 in python/sdist/amici/jax/model.py

Codecov / codecov/patch

python/sdist/amici/jax/model.py#L490

Added line #L490 was not covered by tests

x = jnp.where(mask_reinit, x_reinit, x)
x_solver = self._x_solver(x)
tcl = self._tcl(x, p)

Check warning on line 494 in python/sdist/amici/jax/model.py

Codecov / codecov/patch

python/sdist/amici/jax/model.py#L492-L494

Added lines #L492 - L494 were not covered by tests

tcl = self._tcl(x0, p)
x_preq = jnp.repeat(current_x.reshape(1, -1), ts_init.shape[0], axis=0)
x_preq = jnp.repeat(x_solver.reshape(1, -1), ts_init.shape[0], axis=0)

Check warning on line 496 in python/sdist/amici/jax/model.py

Codecov / codecov/patch

python/sdist/amici/jax/model.py#L496

Added line #L496 was not covered by tests

# Dynamic simulation
if ts_dyn.shape[0] > 0:
x_dyn, stats_dyn = self._solve(
p,
ts_dyn,
tcl,
current_x,
x_solver,
solver,
controller,
max_steps,
adjoint,
)
current_x = x_dyn[-1, :]
x_solver = x_dyn[-1, :]

Check warning on line 510 in python/sdist/amici/jax/model.py

Codecov / codecov/patch

python/sdist/amici/jax/model.py#L510

Added line #L510 was not covered by tests
else:
x_dyn = jnp.repeat(
current_x.reshape(1, -1), ts_dyn.shape[0], axis=0
x_solver.reshape(1, -1), ts_dyn.shape[0], axis=0
)
stats_dyn = None

# Post-equilibration
if ts_posteq.shape[0] > 0:
current_x, stats_posteq = self._eq(
p, tcl, current_x, solver, controller, max_steps
x_solver, stats_posteq = self._eq(

Check warning on line 519 in python/sdist/amici/jax/model.py

Codecov / codecov/patch

python/sdist/amici/jax/model.py#L519

Added line #L519 was not covered by tests
p, tcl, x_solver, solver, controller, max_steps
)
else:
stats_posteq = None

x_posteq = jnp.repeat(
current_x.reshape(1, -1), ts_posteq.shape[0], axis=0
x_solver.reshape(1, -1), ts_posteq.shape[0], axis=0
)

ts = jnp.concatenate((ts_init, ts_dyn, ts_posteq), axis=0)

Check warning on line 529 in python/sdist/amici/jax/model.py

Codecov / codecov/patch

python/sdist/amici/jax/model.py#L529

Added line #L529 was not covered by tests
x = jnp.concatenate((x_preq, x_dyn, x_posteq), axis=0)

nllhs = self._nllhs(ts, x, p, tcl, my, iys)
@@ -548,8 +549,8 @@
stats_posteq=stats_posteq,
)

@eqx.filter_jit
def preequilibrate_condition(

Check warning on line 553 in python/sdist/amici/jax/model.py

Codecov / codecov/patch

python/sdist/amici/jax/model.py#L552-L553

Added lines #L552 - L553 were not covered by tests
self,
p: jt.Float[jt.Array, "np"],
solver: diffrax.AbstractSolver,
@@ -571,14 +572,14 @@
pre-equilibrated state variables and statistics
"""
# Pre-equilibration
x0 = self._x0(p)
tcl = self._tcl(x0, p)
current_x = self._x_solver(x0)
current_x, stats_preeq = self._eq(

Check warning on line 578 in python/sdist/amici/jax/model.py

Codecov / codecov/patch

python/sdist/amici/jax/model.py#L575-L578

Added lines #L575 - L578 were not covered by tests
p, tcl, current_x, solver, controller, max_steps
)

return self._x_rdata(current_x, tcl), dict(stats_preeq=stats_preeq)

Check warning on line 582 in python/sdist/amici/jax/model.py

Codecov / codecov/patch

python/sdist/amici/jax/model.py#L582

Added line #L582 was not covered by tests


def safe_log(x: jnp.float_) -> jnp.float_:
69 changes: 69 additions & 0 deletions python/sdist/amici/jax/petab.py
Original file line number Diff line number Diff line change
@@ -292,6 +292,70 @@
)
return self._unscale(p, pscale)

def load_reinitialisation(

Check warning on line 295 in python/sdist/amici/jax/petab.py

Codecov / codecov/patch

python/sdist/amici/jax/petab.py#L295

Added line #L295 was not covered by tests
self,
simulation_condition: str,
p: jt.Float[jt.Array, "np"],
) -> tuple[jt.Bool[jt.Array, "nx"], jt.Float[jt.Array, "nx"]]: # noqa: F821
"""
Load reinitialisation values and mask for the state vector for a simulation condition.

:param simulation_condition:
Simulation condition to load reinitialisation for.
:param p:
Parameters for the simulation condition.
:return:
Tuple of reinitialisation masm and value for states.
"""
mask = jax.lax.stop_gradient(

Check warning on line 310 in python/sdist/amici/jax/petab.py

Codecov / codecov/patch

python/sdist/amici/jax/petab.py#L310

Added line #L310 was not covered by tests
jnp.array(
[
xname in self._petab_problem.condition_df
and not (
isinstance(
self._petab_problem.condition_df.loc[
simulation_condition, xname
],
Number,
)
and np.isnan(
self._petab_problem.condition_df.loc[
simulation_condition, xname
]
)
)
for xname in self.model.state_ids
]
)
)
reinit_x = jnp.array(

Check warning on line 331 in python/sdist/amici/jax/petab.py

Codecov / codecov/patch

python/sdist/amici/jax/petab.py#L331

Added line #L331 was not covered by tests
[
0
if xname not in self._petab_problem.condition_df
or (
isinstance(
xval := self._petab_problem.condition_df.loc[
simulation_condition, xname
],
Number,
)
and np.isnan(xval)
)
else xval
if isinstance(
xval := self._petab_problem.condition_df.loc[
simulation_condition, xname
],
Number,
)
else p
if xval in self.model.parameter_ids
else self.get_petab_parameter_by_id(xval)
for xname in self.model.state_ids
]
)
return mask, reinit_x

Check warning on line 357 in python/sdist/amici/jax/petab.py

Codecov / codecov/patch

python/sdist/amici/jax/petab.py#L357

Added line #L357 was not covered by tests

def update_parameters(self, p: jt.Float[jt.Array, "np"]) -> "JAXProblem":
"""
Update parameters for the model.
@@ -329,6 +393,9 @@
simulation_condition
]
p = self.load_parameters(simulation_condition[0])
mask_reinit, x_reinit = self.load_reinitialisation(

Check warning on line 396 in python/sdist/amici/jax/petab.py

Codecov / codecov/patch

python/sdist/amici/jax/petab.py#L396

Added line #L396 was not covered by tests
simulation_condition[0], p
)
return self.model.simulate_condition(
p=p,
ts_init=jax.lax.stop_gradient(jnp.array(ts_preeq)),
@@ -337,13 +404,15 @@
my=jax.lax.stop_gradient(jnp.array(my)),
iys=jax.lax.stop_gradient(jnp.array(iys)),
x_preeq=x_preeq,
mask_reinit=mask_reinit,
x_reinit=x_reinit,
solver=solver,
controller=controller,
max_steps=max_steps,
adjoint=diffrax.RecursiveCheckpointAdjoint(),
)

def run_preequilibration(

Check warning on line 415 in python/sdist/amici/jax/petab.py

Codecov / codecov/patch

python/sdist/amici/jax/petab.py#L415

Added line #L415 was not covered by tests
self,
simulation_condition: str,
solver: diffrax.AbstractSolver,
@@ -364,8 +433,8 @@
:return:
Pre-equilibration state
"""
p = self.load_parameters(simulation_condition)
return self.model.preequilibrate_condition(

Check warning on line 437 in python/sdist/amici/jax/petab.py

Codecov / codecov/patch

python/sdist/amici/jax/petab.py#L436-L437

Added lines #L436 - L437 were not covered by tests
p=p,
solver=solver,
controller=controller,
@@ -405,7 +474,7 @@
if simulation_conditions is None:
simulation_conditions = problem.get_all_simulation_conditions()

preeqs = {

Check warning on line 477 in python/sdist/amici/jax/petab.py

Codecov / codecov/patch

python/sdist/amici/jax/petab.py#L477

Added line #L477 was not covered by tests
sc: problem.run_preequilibration(sc, solver, controller, max_steps)
# only run preequilibration once per condition
for sc in {sc[1] for sc in simulation_conditions if len(sc) > 1}
@@ -421,4 +490,4 @@
)
for sc in simulation_conditions
}
return sum(llh for llh, _ in results.values()), results | preeqs

Check warning on line 493 in python/sdist/amici/jax/petab.py

Codecov / codecov/patch

python/sdist/amici/jax/petab.py#L493

Added line #L493 was not covered by tests
13 changes: 8 additions & 5 deletions python/sdist/amici/petab/sbml_import.py
Original file line number Diff line number Diff line change
@@ -348,11 +348,14 @@
_workaround_observable_parameters(
observables, sigmas, sbml_model, output_parameter_defaults
)
fixed_parameters = _workaround_initial_states(
petab_problem=petab_problem,
sbml_model=sbml_model,
**kwargs,
)
if not jax:
fixed_parameters = _workaround_initial_states(
petab_problem=petab_problem,
sbml_model=sbml_model,
**kwargs,
)
else:
fixed_parameters = []

fixed_parameters.extend(
_get_fixed_parameters_sbml(
@@ -607,7 +610,7 @@

# try sbml model id
if sbml_model_id := sbml_model.getId():
return BASE_DIR / (sbml_model_id + suffix)

Check warning on line 613 in python/sdist/amici/petab/sbml_import.py

Codecov / codecov/patch

python/sdist/amici/petab/sbml_import.py#L613

Added line #L613 was not covered by tests

# create random folder name
return Path(tempfile.mkdtemp(dir=BASE_DIR))

Unchanged files with check annotations Beta

output_dir = Path(os.getcwd()) / f"amici-{self.model_name}"
self.model_path = Path(output_dir).resolve()
self.model_path.mkdir(parents=True, exist_ok=True)

Check warning on line 259 in python/sdist/amici/jax/ode_export.py

Codecov / codecov/patch

python/sdist/amici/jax/ode_export.py#L259

Added line #L259 was not covered by tests
def set_name(self, model_name: str) -> None:
"""
:return:
"""
if jax:
outdir = Path(model_output_dir)
return outdir.stem, outdir.parent

Check warning on line 291 in python/sdist/amici/petab/import_helpers.py

Codecov / codecov/patch

python/sdist/amici/petab/import_helpers.py#L290-L291

Added lines #L290 - L291 were not covered by tests
else:
return model_name, Path(model_output_dir)