diff --git a/python/examples/example_jax_petab/ExampleJaxPEtab.ipynb b/python/examples/example_jax_petab/ExampleJaxPEtab.ipynb index a77cf64c69..a937f76d47 100644 --- a/python/examples/example_jax_petab/ExampleJaxPEtab.ipynb +++ b/python/examples/example_jax_petab/ExampleJaxPEtab.ipynb @@ -357,14 +357,12 @@ "simulation_condition = (\"model1_data1\",)\n", "\n", "# Load condition-specific data\n", - "ts_preeq, ts_dyn, ts_posteq, my, iys = jax_problem._measurements[\n", + "ts_init, ts_dyn, ts_posteq, my, iys = jax_problem._measurements[\n", " simulation_condition\n", "]\n", "\n", "# Load parameters for the specified condition\n", "p = jax_problem.load_parameters(simulation_condition[0])\n", - "# Disable preequilibration\n", - "p_preeq = jnp.array([])\n", "\n", "\n", "# Define a function to compute the gradient with respect to dynamic timepoints\n", @@ -372,8 +370,7 @@ "def grad_ts_dyn(tt):\n", " return jax_problem.model.simulate_condition(\n", " p=p,\n", - " p_preeq=p_preeq,\n", - " ts_init=ts_preeq,\n", + " ts_init=ts_init,\n", " ts_dyn=tt,\n", " ts_posteq=ts_posteq,\n", " my=jnp.array(my),\n",