diff --git a/python/examples/example_jax_petab/ExampleJaxPEtab.ipynb b/python/examples/example_jax_petab/ExampleJaxPEtab.ipynb index 1310091f4c..6d645a1451 100644 --- a/python/examples/example_jax_petab/ExampleJaxPEtab.ipynb +++ b/python/examples/example_jax_petab/ExampleJaxPEtab.ipynb @@ -25,10 +25,11 @@ ] }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, + "id": "c71c96da0da3144a", + "metadata": {}, + "outputs": [], "source": [ "from amici.petab.petab_import import import_petab_problem\n", "import petab.v1 as petab\n", @@ -49,24 +50,24 @@ " verbose=False, # no text output\n", " jax=True, # return jax model\n", ")" - ], - "id": "c71c96da0da3144a" + ] }, { - "metadata": {}, "cell_type": "markdown", + "id": "7e0f1c27bd71ee1f", + "metadata": {}, "source": [ "## Simulation\n", "\n", "In principle, we can already use this model for simulation using the [simulate_condition](https://amici.readthedocs.io/en/latest/generated/amici.jax.html#amici.jax.JAXModel.simulate_condition) method. However, this approach can be cumbersome, as timepoints, data, etc. need to be specified manually. Instead, we process the PEtab problem into a [JAXProblem](https://amici.readthedocs.io/en/latest/generated/amici.jax.html#amici.jax.JAXProblem), which enables efficient simulation using [amici.jax.run_simulations](https://amici.readthedocs.io/en/latest/generated/amici.jax.html#amici.jax.run_simulations)." - ], - "id": "7e0f1c27bd71ee1f" + ] }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, + "id": "ccecc9a29acc7b73", + "metadata": {}, + "outputs": [], "source": [ "from amici.jax import JAXProblem, run_simulations\n", "\n", @@ -75,44 +76,44 @@ "\n", "# Run simulations and compute the log-likelihood\n", "llh, results = run_simulations(jax_problem)" - ], - "id": "ccecc9a29acc7b73" + ] }, { - "metadata": {}, "cell_type": "markdown", - "source": "This simulates the model for all conditions using the nominal parameter values. Simple, right? Now, let’s take a look at the simulation results.", - "id": "415962751301c64a" + "id": "415962751301c64a", + "metadata": {}, + "source": "This simulates the model for all conditions using the nominal parameter values. Simple, right? Now, let’s take a look at the simulation results." }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, + "id": "596b86e45e18fe3d", + "metadata": {}, + "outputs": [], "source": [ "# Define the simulation condition\n", "simulation_condition = (\"model1_data1\",)\n", "\n", "# Access the results for the specified condition\n", "results[simulation_condition]" - ], - "id": "596b86e45e18fe3d" + ] }, { - "metadata": {}, "cell_type": "markdown", + "id": "a1b173e013f9210a", + "metadata": {}, "source": [ "Unfortunately, the simulation failed! As seen in the output, the simulation broke down after the initial timepoint, indicated by the `inf` values in the state variables `results[simulation_condition][1].x` and the `nan` likelihood value. A closer inspection of this variable provides additional clues about what might have gone wrong.\n", "\n", "The issue stems from using single precision, as indicated by the `float32` dtype of the state variables. Single precision is generally a [bad idea](https://docs.kidger.site/diffrax/examples/stiff_ode/) for stiff systems like the Böhm model. Let’s retry the simulation with double precision."
- ], - "id": "a1b173e013f9210a" + ] }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, + "id": "f4f5ff705a3f7402", + "metadata": {}, + "outputs": [], "source": [ "import jax\n", "\n", @@ -123,20 +124,20 @@ "llh, results = run_simulations(jax_problem)\n", "\n", "results" - ], - "id": "f4f5ff705a3f7402" + ] }, { - "metadata": {}, "cell_type": "markdown", - "source": "Success! The simulation completed successfully, and we can now plot the resulting state trajectories.", - "id": "fe4d3b40ee3efdf2" + "id": "fe4d3b40ee3efdf2", + "metadata": {}, + "source": "Success! The simulation completed successfully, and we can now plot the resulting state trajectories." }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, + "id": "72f1ed397105e14a", + "metadata": {}, + "outputs": [], "source": [ "%matplotlib inline\n", "import matplotlib.pyplot as plt\n", @@ -171,41 +172,41 @@ "\n", "# Plot the simulation results\n", "plot_simulation(results)" - ], - "id": "72f1ed397105e14a" + ] }, { - "metadata": {}, "cell_type": "markdown", - "source": "`run_simulations` enables users to specify the simulation conditions to be executed. For more complex models, this allows for restricting simulations to a subset of conditions. Since the Böhm model includes only a single condition, we demonstrate this functionality by simulating no condition at all.", - "id": "4fa97c33719c2277" + "id": "4fa97c33719c2277", + "metadata": {}, + "source": "`run_simulations` enables users to specify the simulation conditions to be executed. For more complex models, this allows for restricting simulations to a subset of conditions. Since the Böhm model includes only a single condition, we demonstrate this functionality by simulating no condition at all." }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, + "id": "7950774a3e989042", + "metadata": {}, + "outputs": [], "source": [ "llh, results = run_simulations(jax_problem, simulation_conditions=tuple())\n", "results" - ], - "id": "7950774a3e989042" + ] }, { - "metadata": {}, "cell_type": "markdown", + "id": "98b8516a75ce4d12", + "metadata": {}, "source": [ "## Updating Parameters\n", "\n", "As next step, we will update the parameter values used for simulation. However, if we attempt to directly modify the values in `JAXModel.parameters`, we encounter a `FrozenInstanceError`." - ], - "id": "98b8516a75ce4d12" + ] }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, + "id": "3d278a3d21e709d", + "metadata": {}, + "outputs": [], "source": [ "from dataclasses import FrozenInstanceError\n", "import jax\n", @@ -223,24 +224,24 @@ " jax_problem.parameters += noise\n", "except FrozenInstanceError as e:\n", " print(\"Error:\", e)" - ], - "id": "3d278a3d21e709d" + ] }, { - "metadata": {}, "cell_type": "markdown", + "id": "4cc3d595de4a4085", + "metadata": {}, "source": [ "The root cause of this error lies in the fact that, to enable autodiff, direct modifications of attributes are not allowed in [equinox](https://docs.kidger.site/equinox/), which AMICI utilizes under the hood. Consequently, attributes of instances like `JAXModel` or `JAXProblem` cannot be updated directly — this is the price we have to pay for autodiff.\n", "\n", "However, `JAXProblem` provides a convenient method called [update_parameters](https://amici.readthedocs.io/en/latest/generated/amici.jax.html#amici.jax.JAXProblem.update_parameters). 
The caveat is that this method creates a new JAXProblem instance instead of modifying the existing one." - ], - "id": "4cc3d595de4a4085" + ] }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, + "id": "e47748376059628b", + "metadata": {}, + "outputs": [], "source": [ "# Update the parameters and create a new JAXProblem instance\n", "jax_problem = jax_problem.update_parameters(jax_problem.parameters + noise)\n", @@ -250,105 +251,111 @@ "\n", "# Plot the simulation results\n", "plot_simulation(results)" - ], - "id": "e47748376059628b" + ] }, { - "metadata": {}, "cell_type": "markdown", + "id": "660baf605a4e8339", + "metadata": {}, "source": [ "## Computing Gradients\n", "\n", "Similar to updating attributes, computing gradients can feel a bit unconventional if you’re not familiar with the JAX ecosystem. JAX offers [powerful automatic differentiation](https://jax.readthedocs.io/en/latest/automatic-differentiation.html) through the `jax.grad` function. However, to use `jax.grad` with `JAXProblem`, we need to specify which parts of the `JAXProblem` should be treated as static." - ], - "id": "660baf605a4e8339" + ] }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, + "id": "7033d09cc81b7f69", + "metadata": {}, + "outputs": [], "source": [ "try:\n", " # Attempt to compute the gradient of the run_simulations function\n", " jax.grad(run_simulations, has_aux=True)(jax_problem)\n", "except TypeError as e:\n", " print(\"Error:\", e)" - ], - "id": "7033d09cc81b7f69" + ] }, { - "metadata": {}, "cell_type": "markdown", - "source": "Fortunately, `equinox` simplifies this process by offering [filter_grad](https://docs.kidger.site/equinox/api/transformations/#equinox.filter_grad), which enables autodiff functionality that is compatible with `JAXProblem` and, in theory, also with `JAXModel`.", - "id": "dc9bc07cde00a926" + "id": "dc9bc07cde00a926", + "metadata": {}, + "source": "Fortunately, `equinox` simplifies this process by offering [filter_grad](https://docs.kidger.site/equinox/api/transformations/#equinox.filter_grad), which enables autodiff functionality that is compatible with `JAXProblem` and, in theory, also with `JAXModel`." }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, + "id": "a6704182200e6438", + "metadata": {}, + "outputs": [], "source": [ "import equinox as eqx\n", "\n", "# Compute the gradient using equinox's filter_grad, preserving auxiliary outputs\n", "grad, _ = eqx.filter_grad(run_simulations, has_aux=True)(jax_problem)" - ], - "id": "a6704182200e6438" + ] }, { - "metadata": {}, "cell_type": "markdown", - "source": "Functions transformed by `filter_grad` return gradients that share the same structure as the first argument (unless specified otherwise). This allows us to access the gradient with respect to the `parameters` attribute directly via `grad.parameters`.", - "id": "851c3ec94cb5d086" + "id": "851c3ec94cb5d086", + "metadata": {}, + "source": "Functions transformed by `filter_grad` return gradients that share the same structure as the first argument (unless specified otherwise). This allows us to access the gradient with respect to the `parameters` attribute directly via `grad.parameters`."
}, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, - "source": "grad.parameters", - "id": "c00c1581d7173d7a" + "id": "c00c1581d7173d7a", + "metadata": {}, + "outputs": [], + "source": [ + "grad.parameters" + ] }, { - "metadata": {}, "cell_type": "markdown", - "source": "Attributes for which derivatives cannot be computed (typically anything that is not a [jax.numpy.array](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.array.html)) are automatically set to `None`.", - "id": "375b835fecc5a022" + "id": "375b835fecc5a022", + "metadata": {}, + "source": "Attributes for which derivatives cannot be computed (typically anything that is not a [jax.numpy.array](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.array.html)) are automatically set to `None`." }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, - "source": "grad", - "id": "f7c17f7459d0151f" + "id": "f7c17f7459d0151f", + "metadata": {}, + "outputs": [], + "source": [ + "grad" + ] }, { - "metadata": {}, "cell_type": "markdown", - "source": "Observant readers may notice that the gradient above appears to include numeric values for derivatives with respect to some measurements. However, `simulation_conditions` internally disables gradient computations using `jax.lax.stop_gradient`, resulting in these values being zeroed out.", - "id": "8eb7cc3db510c826" + "id": "8eb7cc3db510c826", + "metadata": {}, + "source": "Observant readers may notice that the gradient above appears to include numeric values for derivatives with respect to some measurements. However, `simulation_conditions` internally disables gradient computations using `jax.lax.stop_gradient`, resulting in these values being zeroed out." }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, - "source": "grad._measurements[simulation_condition]", - "id": "3badd4402cf6b8c6" + "id": "3badd4402cf6b8c6", + "metadata": {}, + "outputs": [], + "source": [ + "grad._measurements[simulation_condition]" + ] }, { - "metadata": {}, "cell_type": "markdown", - "source": "However, we can compute derivatives with respect to data elements using `JAXModel.simulate_condition`. In the example below, we differentiate the observables `y` (specified by passing `y` to the `ret` argument) with respect to the timepoints at which the model outputs are computed after the solving the differential equation. While this might not be particularly practical, it serves as an nice illustration of the power of automatic differentiation.", - "id": "58eb04393a1463d" + "id": "58eb04393a1463d", + "metadata": {}, + "source": "However, we can compute derivatives with respect to data elements using `JAXModel.simulate_condition`. In the example below, we differentiate the observables `y` (specified by passing `y` to the `ret` argument) with respect to the timepoints at which the model outputs are computed after the solving the differential equation. While this might not be particularly practical, it serves as an nice illustration of the power of automatic differentiation." 
}, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, + "id": "1a91aff44b93157", + "metadata": {}, + "outputs": [], "source": [ "import jax.numpy as jnp\n", "import diffrax\n", @@ -363,7 +370,7 @@ "]\n", "\n", "# Load parameters for the specified condition\n", - "p = jax_problem.load_parameters(simulation_condition[0])\n", + "p = jax_problem.load_model_parameters(simulation_condition[0])\n", "\n", "\n", "# Define a function to compute the gradient with respect to dynamic timepoints\n", @@ -388,24 +395,24 @@ "# Compute the gradient with respect to `ts_dyn`\n", "g = grad_ts_dyn(ts_dyn)\n", "g" - ], - "id": "1a91aff44b93157" + ] }, { - "metadata": {}, "cell_type": "markdown", + "id": "9f870da7754e139c", + "metadata": {}, "source": [ "## Compilation & Profiling\n", "\n", "To maximize performance with JAX, code should be just-in-time (JIT) compiled. This can be achieved using the `jax.jit` or `equinox.filter_jit` decorators. While JIT compilation introduces some overhead during the first function call, it significantly improves performance for subsequent calls. To demonstrate this, we will first clear the JIT cache and then profile the execution." - ], - "id": "9f870da7754e139c" + ] }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, + "id": "58ebdc110ea7457e", + "metadata": {}, + "outputs": [], "source": [ "from time import time\n", "\n", @@ -414,14 +421,14 @@ "\n", "# Define a JIT-compiled gradient function with auxiliary outputs\n", "gradfun = eqx.filter_jit(eqx.filter_grad(run_simulations, has_aux=True))" - ], - "id": "58ebdc110ea7457e" + ] }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, + "id": "e1242075f7e0faf", + "metadata": {}, + "outputs": [], "source": [ "# Measure the time taken for the first function call (including compilation)\n", "start = time()\n", @@ -432,14 +439,14 @@ "start = time()\n", "gradfun(jax_problem)\n", "print(f\"Gradient compilation time: {time() - start:.2f} seconds\")" - ], - "id": "e1242075f7e0faf" + ] }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, + "id": "27181f367ccb1817", + "metadata": {}, + "outputs": [], "source": [ "%%timeit\n", "run_simulations(\n", @@ -452,14 +459,14 @@ " dcoeff=0.0, # recommended value for stiff systems\n", " ),\n", ")" - ], - "id": "27181f367ccb1817" + ] }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, + "id": "5b8d3a6162a3ae55", + "metadata": {}, + "outputs": [], "source": [ "%%timeit \n", "gradfun(\n", @@ -472,14 +479,14 @@ " dcoeff=0.0, # recommended value for stiff systems\n", " ),\n", ")" - ], - "id": "5b8d3a6162a3ae55" + ] }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, + "id": "d733a450635a749b", + "metadata": {}, + "outputs": [], "source": [ "from amici.petab import simulate_petab\n", "import amici\n", @@ -500,8 +507,7 @@ "problem_parameters = dict(\n", " zip(jax_problem.parameter_ids, jax_problem.parameters)\n", ")" - ], - "id": "d733a450635a749b" + ] }, { "cell_type": "code",