diff --git a/python/sdist/amici/jax/model.py b/python/sdist/amici/jax/model.py index 8f9650ef0f..0a03e95751 100644 --- a/python/sdist/amici/jax/model.py +++ b/python/sdist/amici/jax/model.py @@ -433,6 +433,8 @@ def simulate_condition( my: jt.Float[jt.Array, "nt"], iys: jt.Int[jt.Array, "nt"], x_preeq: jt.Float[jt.Array, "nx"], + mask_reinit: jt.Bool[jt.Array, "nx"], + x_reinit: jt.Float[jt.Array, "nx"], solver: diffrax.AbstractSolver, controller: diffrax.AbstractStepSizeController, adjoint: diffrax.AbstractAdjoint, @@ -482,17 +484,16 @@ def simulate_condition( :return: output according to `ret` and statistics """ - # Pre-equilibration if x_preeq.shape[0] > 0: - current_x = self._x_solver(x_preeq) - # update tcl with new parameters - tcl = self._tcl(x_preeq, p) + x = x_preeq else: - x0 = self._x0(p) - current_x = self._x_solver(x0) + x = self._x0(p) + + x = jnp.where(mask_reinit, x_reinit, x) + x_solver = self._x_solver(x) + tcl = self._tcl(x, p) - tcl = self._tcl(x0, p) - x_preq = jnp.repeat(current_x.reshape(1, -1), ts_init.shape[0], axis=0) + x_preq = jnp.repeat(x_solver.reshape(1, -1), ts_init.shape[0], axis=0) # Dynamic simulation if ts_dyn.shape[0] > 0: @@ -500,29 +501,29 @@ def simulate_condition( p, ts_dyn, tcl, - current_x, + x_solver, solver, controller, max_steps, adjoint, ) - current_x = x_dyn[-1, :] + x_solver = x_dyn[-1, :] else: x_dyn = jnp.repeat( - current_x.reshape(1, -1), ts_dyn.shape[0], axis=0 + x_solver.reshape(1, -1), ts_dyn.shape[0], axis=0 ) stats_dyn = None # Post-equilibration if ts_posteq.shape[0] > 0: - current_x, stats_posteq = self._eq( - p, tcl, current_x, solver, controller, max_steps + x_solver, stats_posteq = self._eq( + p, tcl, x_solver, solver, controller, max_steps ) else: stats_posteq = None x_posteq = jnp.repeat( - current_x.reshape(1, -1), ts_posteq.shape[0], axis=0 + x_solver.reshape(1, -1), ts_posteq.shape[0], axis=0 ) ts = jnp.concatenate((ts_init, ts_dyn, ts_posteq), axis=0) diff --git a/python/sdist/amici/jax/petab.py 
b/python/sdist/amici/jax/petab.py index bb8749e27c..50c0154ee3 100644 --- a/python/sdist/amici/jax/petab.py +++ b/python/sdist/amici/jax/petab.py @@ -292,6 +292,70 @@ def load_parameters( ) return self._unscale(p, pscale) + def load_reinitialisation( + self, + simulation_condition: str, + p: jt.Float[jt.Array, "np"], + ) -> tuple[jt.Bool[jt.Array, "nx"], jt.Float[jt.Array, "nx"]]: # noqa: F821 + """ + Load reinitialisation values and mask for the state vector for a simulation condition. + + :param simulation_condition: + Simulation condition to load reinitialisation for. + :param p: + Parameters for the simulation condition. + :return: + Tuple of reinitialisation mask and value for states. + """ + mask = jax.lax.stop_gradient( + jnp.array( + [ + xname in self._petab_problem.condition_df + and not ( + isinstance( + self._petab_problem.condition_df.loc[ + simulation_condition, xname + ], + Number, + ) + and np.isnan( + self._petab_problem.condition_df.loc[ + simulation_condition, xname + ] + ) + ) + for xname in self.model.state_ids + ] + ) + ) + reinit_x = jnp.array( + [ + 0 + if xname not in self._petab_problem.condition_df + or ( + isinstance( + xval := self._petab_problem.condition_df.loc[ + simulation_condition, xname + ], + Number, + ) + and np.isnan(xval) + ) + else xval + if isinstance( + xval := self._petab_problem.condition_df.loc[ + simulation_condition, xname + ], + Number, + ) + else p + if xval in self.model.parameter_ids + else self.get_petab_parameter_by_id(xval) + for xname in self.model.state_ids + ] + ) + return mask, reinit_x + def update_parameters(self, p: jt.Float[jt.Array, "np"]) -> "JAXProblem": """ Update parameters for the model. 
@@ -329,6 +393,9 @@ def run_simulation( simulation_condition ] p = self.load_parameters(simulation_condition[0]) + mask_reinit, x_reinit = self.load_reinitialisation( + simulation_condition[0], p + ) return self.model.simulate_condition( p=p, ts_init=jax.lax.stop_gradient(jnp.array(ts_preeq)), @@ -337,6 +404,8 @@ def run_simulation( my=jax.lax.stop_gradient(jnp.array(my)), iys=jax.lax.stop_gradient(jnp.array(iys)), x_preeq=x_preeq, + mask_reinit=mask_reinit, + x_reinit=x_reinit, solver=solver, controller=controller, max_steps=max_steps, diff --git a/python/sdist/amici/petab/sbml_import.py b/python/sdist/amici/petab/sbml_import.py index e4e7efd7fc..e605a9cc80 100644 --- a/python/sdist/amici/petab/sbml_import.py +++ b/python/sdist/amici/petab/sbml_import.py @@ -348,11 +348,14 @@ def import_model_sbml( _workaround_observable_parameters( observables, sigmas, sbml_model, output_parameter_defaults ) - fixed_parameters = _workaround_initial_states( - petab_problem=petab_problem, - sbml_model=sbml_model, - **kwargs, - ) + if not jax: + fixed_parameters = _workaround_initial_states( + petab_problem=petab_problem, + sbml_model=sbml_model, + **kwargs, + ) + else: + fixed_parameters = [] fixed_parameters.extend( _get_fixed_parameters_sbml(