Skip to content
This repository has been archived by the owner on Sep 24, 2024. It is now read-only.

Commit

Permalink
adding autodiff text
Browse files Browse the repository at this point in the history
  • Loading branch information
hatemhelal committed Apr 26, 2024
1 parent 54e1a98 commit df2dc1b
Showing 1 changed file with 11 additions and 7 deletions.
18 changes: 11 additions & 7 deletions docs/batching.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -148,18 +148,21 @@
"Finally we are ready to run the optimisation. We define a simple loss function that \n",
"simply takes the sum over the energy for each molecular conformation in the batch.\n",
"In simple terms, the optimisation will follow the gradient to make this sum as small as\n",
"possible."
"possible. We use the transformation\n",
"[jax.value_and_grad](https://jax.readthedocs.io/en/latest/_autosummary/jax.value_and_grad.html)\n",
"as a function decorator on this loss function to evaluate the loss and the corresponding\n",
"gradient. Note that the gradient is computed using automatic differentiation."
]
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3eeffa13512946c683c9788d956cbc1b",
"model_id": "1dd410e1939246b4825d14cfa1ae0296",
"version_major": 2,
"version_minor": 0
},
Expand All @@ -172,14 +175,15 @@
}
],
"source": [
"@jax.value_and_grad\n",
"def loss_fn(z, h):\n",
" return jnp.sum(energy(z, h))\n",
"\n",
"\n",
"history = []\n",
"\n",
"for _ in (bar := tqdm(range(128))):\n",
" loss, grads = jax.value_and_grad(loss_fn)(Z, H)\n",
" loss, grads = loss_fn(Z, H)\n",
" updates, state = optimiser.update(grads, state)\n",
" Z = optax.apply_updates(Z, updates)\n",
" history.append(loss)\n",
Expand All @@ -195,7 +199,7 @@
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": 7,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -230,7 +234,7 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": 8,
"metadata": {},
"outputs": [
{
Expand Down

0 comments on commit df2dc1b

Please sign in to comment.