diff --git a/python/sdist/amici/jax/petab.py b/python/sdist/amici/jax/petab.py index b5834223fb..b2b42f5c2a 100644 --- a/python/sdist/amici/jax/petab.py +++ b/python/sdist/amici/jax/petab.py @@ -500,7 +500,7 @@ def run_simulation( simulation_condition[0], p ) return self.model.simulate_condition( - p=eqx.debug.backward_nan(p), + p=p, ts_init=jax.lax.stop_gradient(jnp.array(ts_preeq)), ts_dyn=jax.lax.stop_gradient(jnp.array(ts_dyn)), ts_posteq=jax.lax.stop_gradient(jnp.array(ts_posteq)),