From df2dc1bbcde05971b4b4bac762182609be0f93cf Mon Sep 17 00:00:00 2001 From: Hatem Helal Date: Fri, 26 Apr 2024 15:27:35 +0000 Subject: [PATCH] adding autodiff text --- docs/batching.ipynb | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/docs/batching.ipynb b/docs/batching.ipynb index a539935..a013dbf 100644 --- a/docs/batching.ipynb +++ b/docs/batching.ipynb @@ -132,7 +132,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -148,18 +148,21 @@ "Finally we are ready to run the optimisation. We define a simple loss function that \n", "simply takes the sum over the energy for each molecular conformation in the batch.\n", "In simple terms, the optimisation will follow the gradient to make this sum as small as\n", - "possible." + "possible. We use the transformation\n", + "[jax.value_and_grad](https://jax.readthedocs.io/en/latest/_autosummary/jax.value_and_grad.html)\n", + "as a function decorator on this loss function to evaluate the loss and the corresponding\n", + "gradient. Note that the gradient is computed using automatic differentiation." ] }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "3eeffa13512946c683c9788d956cbc1b", + "model_id": "1dd410e1939246b4825d14cfa1ae0296", "version_major": 2, "version_minor": 0 }, @@ -172,6 +175,7 @@ } ], "source": [ + "@jax.value_and_grad\n", "def loss_fn(z, h):\n", " return jnp.sum(energy(z, h))\n", "\n", @@ -179,7 +183,7 @@ "history = []\n", "\n", "for _ in (bar := tqdm(range(128))):\n", - " loss, grads = jax.value_and_grad(loss_fn)(Z, H)\n", + " loss, grads = loss_fn(Z, H)\n", " updates, state = optimiser.update(grads, state)\n", " Z = optax.apply_updates(Z, updates)\n", " history.append(loss)\n", @@ -195,7 +199,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 7, "metadata": {}, "outputs": [ { @@ -230,7 +234,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 8, "metadata": {}, "outputs": [ {