This repository has been archived by the owner on Sep 24, 2024. It is now read-only.

Commit

expanding text in batching notebook
hatemhelal committed Apr 25, 2024
1 parent b85d0cf commit cece2cf
Showing 2 changed files with 90 additions and 8 deletions.
96 changes: 89 additions & 7 deletions docs/batching.ipynb
@@ -13,10 +13,39 @@
"# Copyright (c) 2024 Graphcore Ltd. All rights reserved."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Batching\n",
"\n",
"Training neural networks over batches of data is essential for efficiently\n",
"utilising massively parallel hardware accelerators. We can recast a standard electronic\n",
"structure minimisation problem as a batched one by using the JAX vectorising map\n",
"([jax.vmap](https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html#jax-vmap)).\n",
"This lets us parallelise the electronic structure optimisation over multiple\n",
"conformations of the same molecule. Just as in training neural networks, this\n",
"allows for more efficient utilisation of the accelerator, enabling the\n",
"exploration of\n",
"[potential energy surfaces](https://en.wikipedia.org/wiki/Potential_energy_surface)\n",
"with quantum-mechanical simulations.\n",
"\n",
"We demonstrate this idea by calculating the molecular hydrogen dissociation curve using\n",
"a batch of hydrogen molecules where the bond length (H-H distance) is varied. To set up\n",
"the batch we build a `Hamiltonian` for each bond length and stack the built modules to\n",
"create a batched Hamiltonian. This example uses the `sto-3g` basis set and the simple\n",
"local density approximation of density functional theory, but this formulation is not\n",
"tied to these choices for how the Hamiltonian is represented."
]
},
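The stacking step described above can be sketched in plain JAX. Here a hypothetical `make_conf` stands in for building a `Hamiltonian` at a given bond length; the point is that `tree_map` with `jnp.stack` gives every array leaf a leading batch axis:

```python
import jax
import jax.numpy as jnp

def make_conf(r):
    # Toy stand-in for a per-bond-length Hamiltonian: a pytree of arrays.
    return {"r": jnp.asarray(r), "core": r * jnp.eye(2)}

confs = [make_conf(r) for r in (0.6, 1.0, 1.4)]
# Stack leaf-by-leaf so every array gains a leading batch axis of size 3.
batched = jax.tree_util.tree_map(lambda *xs: jnp.stack(xs), *confs)
print(batched["core"].shape)  # (3, 2, 2)
```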
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"metadata": {
"tags": [
"hide-cell"
]
},
"outputs": [
{
"name": "stderr",
@@ -60,7 +89,19 @@
"rs = np.linspace(0.6, 6, num_confs)\n",
"H = [h2_hamiltonian(r) for r in rs]\n",
"num_orbitals = H[0].basis.num_orbitals\n",
"H = jax.tree_map(lambda *xs: jnp.sqtack(xs), *H)"
"H = jax.tree_map(lambda *xs: jnp.stack(xs), *H)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next is where the magic happens: the vectorising transformation from JAX.\n",
"The `energy` function evaluates the energy of a single `Hamiltonian` for the unconstrained\n",
"trial matrix $Z$. Applying `@jax.vmap` converts this single-conformation function into one\n",
"that works on the batched Hamiltonian we constructed earlier. For an extra performance boost\n",
"we use the [jax.jit](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax-jit)\n",
"function transformation to compile this function."
]
},
{
@@ -77,9 +118,21 @@
" return H(P)"
]
},
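The same pattern can be sketched with a toy quadratic energy standing in for the real `Hamiltonian` evaluation (the `energy` below is a hypothetical stand-in, not the notebook's function):

```python
import jax
import jax.numpy as jnp

@jax.jit
@jax.vmap
def energy(h, z):
    # Toy stand-in: build a density-like matrix from the unconstrained trial
    # matrix z and contract it with h to get a scalar "energy".
    p = z @ z.T
    return jnp.sum(h * p)

H_batch = jnp.stack([r * jnp.eye(2) for r in (0.6, 1.0, 1.4)])
Z_batch = jnp.ones((3, 2, 2))
energies = energy(H_batch, Z_batch)  # one energy per conformation
```

Because `vmap` is applied before `jit`, the whole batch is evaluated in a single compiled call rather than one compilation per conformation.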
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next we set up an initial guess and a gradient descent optimiser\n",
"([adam](https://optax.readthedocs.io/en/latest/api/optimizers.html#optax.adam) from the optax library)\n",
"that will solve the batched energy minimisation problem.\n",
"\n",
"The initial guess is somewhat arbitrary, and there are certainly better initialisation\n",
"methods that one could use."
]
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
@@ -88,15 +141,25 @@
"state = optimiser.init(Z)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We are now ready to run the optimisation. We define a loss function that\n",
"takes the sum of the energies of the molecular conformations in the batch.\n",
"The optimisation then follows the gradient to make this sum as small as\n",
"possible."
]
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "0fc26eb61c9a494fae0bb13ddcde1548",
"model_id": "3eeffa13512946c683c9788d956cbc1b",
"version_major": 2,
"version_minor": 0
},
@@ -123,9 +186,16 @@
" bar.set_description(f\"Loss {loss:0.06f}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a sanity check, we plot the loss to see if there is any funny business to investigate."
]
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 18,
"metadata": {},
"outputs": [
{
@@ -146,9 +216,21 @@
"ax.set_ylabel(\"Batched Loss (Hartree)\");"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The loss is decreasing, so this looks good!\n",
"\n",
"Finally, we can evaluate the total energy by once again using our now good friend the\n",
"vectorising map to compute the `nuclear_energy` for each conformation. This is added\n",
"to the electronic energy we minimised above. We can then plot the dissociation\n",
"curve, which looks like something you would find in your nearest chemistry textbook."
]
},
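As a sketch: for H2 the nuclear repulsion between the two unit charges is 1/r in atomic units, so the vectorised computation might look like the following (the electronic energies here are placeholders, not minimised values):

```python
import jax
import jax.numpy as jnp

rs = jnp.linspace(0.6, 6, 3)

@jax.vmap
def nuclear_energy(r):
    # For H2 (two unit nuclear charges) the repulsion in atomic units is 1/r.
    return 1.0 / r

electronic = jnp.array([-1.0, -1.1, -1.05])  # placeholder minimised energies
total = electronic + nuclear_energy(rs)
```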
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 19,
"metadata": {},
"outputs": [
{
Expand Down
2 changes: 1 addition & 1 deletion docs/tour.ipynb
@@ -425,7 +425,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"We've measured a nearly 8X speedup...There are several gotchas of course that will\n",
"We've measured a nearly 8X speedup...There are several gotchas of course that will be\n",
"explored in due time...stay tuned for more!"
]
},
Expand Down
