diff --git a/docs/batching.ipynb b/docs/batching.ipynb
index a539935..a013dbf 100644
--- a/docs/batching.ipynb
+++ b/docs/batching.ipynb
@@ -132,7 +132,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 16,
+   "execution_count": 5,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -148,18 +148,21 @@
     "Finally we are ready to run the optimisation. We define a simple loss function that \n",
     "simply takes the sum over the energy for each molecular conformation in the batch.\n",
     "In simple terms, the optimisation will follow the gradient to make this sum as small as\n",
-    "possible."
+    "possible. We use the transformation\n",
+    "[jax.value_and_grad](https://jax.readthedocs.io/en/latest/_autosummary/jax.value_and_grad.html)\n",
+    "as a function decorator on this loss function to evaluate the loss and the corresponding\n",
+    "gradient. Note that the gradient is computed using automatic differentiation."
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 17,
+   "execution_count": 6,
    "metadata": {},
    "outputs": [
     {
      "data": {
       "application/vnd.jupyter.widget-view+json": {
-       "model_id": "3eeffa13512946c683c9788d956cbc1b",
+       "model_id": "1dd410e1939246b4825d14cfa1ae0296",
        "version_major": 2,
        "version_minor": 0
       },
@@ -172,6 +175,7 @@
     }
    ],
    "source": [
+    "@jax.value_and_grad\n",
     "def loss_fn(z, h):\n",
     "    return jnp.sum(energy(z, h))\n",
     "\n",
@@ -179,7 +183,7 @@
     "history = []\n",
     "\n",
     "for _ in (bar := tqdm(range(128))):\n",
-    "    loss, grads = jax.value_and_grad(loss_fn)(Z, H)\n",
+    "    loss, grads = loss_fn(Z, H)\n",
     "    updates, state = optimiser.update(grads, state)\n",
     "    Z = optax.apply_updates(Z, updates)\n",
     "    history.append(loss)\n",
@@ -195,7 +199,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 18,
+   "execution_count": 7,
    "metadata": {},
    "outputs": [
     {
@@ -230,7 +234,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 19,
+   "execution_count": 8,
    "metadata": {},
    "outputs": [
     {
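For reference, here is a minimal, self-contained sketch of the pattern this change introduces: `jax.value_and_grad` applied as a decorator so that each call to the loss returns both its value and its gradient, which is then fed to an optax optimiser. The `energy` function, array shapes, and learning rate below are placeholders for illustration only, not the notebook's actual model.

```python
import jax
import jax.numpy as jnp
import optax

# Placeholder for the notebook's energy model: any differentiable function of
# the batched coordinates z (and features h) works for illustrating the pattern.
def energy(z, h):
    return jnp.sum((z - h) ** 2, axis=(-2, -1))

# Decorating with jax.value_and_grad makes each call return a (loss, grads) pair,
# where grads is d(loss)/dz computed by automatic differentiation.
@jax.value_and_grad
def loss_fn(z, h):
    return jnp.sum(energy(z, h))

# Arbitrary batch of 8 conformations with 5 atoms in 3D; H is a fixed reference.
Z = jax.random.normal(jax.random.PRNGKey(0), (8, 5, 3))
H = jnp.zeros((8, 5, 3))

optimiser = optax.adam(learning_rate=1e-2)
state = optimiser.init(Z)

for _ in range(128):
    loss, grads = loss_fn(Z, H)                      # value and gradient in one call
    updates, state = optimiser.update(grads, state)  # optax transform of the gradients
    Z = optax.apply_updates(Z, updates)              # gradient step on the coordinates
```

Compared with wrapping inside the loop as `jax.value_and_grad(loss_fn)(Z, H)` (the old code), decorating the definition applies the transformation once and keeps the loop body focused on the optimiser update itself.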