diff --git a/docs/tour.ipynb b/docs/tour.ipynb
index 2590b2f..d50b398 100644
--- a/docs/tour.ipynb
+++ b/docs/tour.ipynb
@@ -1,1154 +1,1187 @@
{
- "cells": [
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {
- "tags": [
- "remove-input"
- ]
- },
- "outputs": [],
- "source": [
- "# Copyright (c) 2024 Graphcore Ltd. All rights reserved."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# Brief Tour of MESS\n",
- "\n",
- "MESS is a modular toolkit for exploring the exciting interface between machine\n",
- "learning, electronic structure, and algorithms.\n",
- "\n",
- "To begin our tour we build a single water molecule.\n",
- "Each atom is represented an atomic number $Z_i$ and a position in Cartesian\n",
- "coordinates $(x_i, y_i, z_i)$. In MESS we collect atoms into a `Structure` and we\n",
- "provide a few examples built by the `molecule` function. \n",
- "MESS is designed for interactive exploration so in a notebook environment a `Structure`\n",
- "object will display a 3D visualisation\n",
- "\n",
- ":::{note}\n",
- "The following code cell will install MESS into the Google Colab runtime.\n",
- "Select the 🚀 in the toolbar above to try this out!\n",
- ":::"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "NNyAVJ9Lo-PO",
- "outputId": "83f4b81f-eb29-4732-9a51-9d9f24fff739",
- "tags": [
- "hide-cell"
- ]
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Collecting git+https://github.com/graphcore-research/mess.git\n",
- " Cloning https://github.com/graphcore-research/mess.git to /tmp/pip-req-build-qdku98_q\n",
- " Running command git clone --filter=blob:none --quiet https://github.com/graphcore-research/mess.git /tmp/pip-req-build-qdku98_q\n",
- " Resolved https://github.com/graphcore-research/mess.git to commit de2f014c8f24699ffdb1b0be2c94be277c19ff1a\n",
- " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
- " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
- " Installing backend dependencies ... \u001b[?25l\u001b[?25hdone\n",
- " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
- "Collecting pyquante2@ git+https://github.com/rpmuller/pyquante2@pure (from mess==0.0.0)\n",
- " Cloning https://github.com/rpmuller/pyquante2 (to revision pure) to /tmp/pip-install-wa358_g8/pyquante2_9ce341ce0b5142878b1e204279e4b4ad\n",
- " Running command git clone --filter=blob:none --quiet https://github.com/rpmuller/pyquante2 /tmp/pip-install-wa358_g8/pyquante2_9ce341ce0b5142878b1e204279e4b4ad\n",
- " Running command git checkout -b pure --track origin/pure\n",
- " Switched to a new branch 'pure'\n",
- " Branch 'pure' set up to track remote branch 'pure' from 'origin'.\n",
- " Resolved https://github.com/rpmuller/pyquante2 to commit 822a1755c83f1730b1b063bc4ab2580a23342c02\n",
- " Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
- "Requirement already satisfied: equinox in /usr/local/lib/python3.10/dist-packages (from mess==0.0.0) (0.11.4)\n",
- "Requirement already satisfied: jax[cpu] in /usr/local/lib/python3.10/dist-packages (from mess==0.0.0) (0.4.26)\n",
- "Requirement already satisfied: jaxtyping in /usr/local/lib/python3.10/dist-packages (from mess==0.0.0) (0.2.28)\n",
- "Requirement already satisfied: more-itertools in /usr/local/lib/python3.10/dist-packages (from mess==0.0.0) (10.1.0)\n",
- "Requirement already satisfied: optax in /usr/local/lib/python3.10/dist-packages (from mess==0.0.0) (0.2.2)\n",
- "Requirement already satisfied: optimistix in /usr/local/lib/python3.10/dist-packages (from mess==0.0.0) (0.0.6)\n",
- "Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from mess==0.0.0) (2.0.3)\n",
- "Requirement already satisfied: periodictable in /usr/local/lib/python3.10/dist-packages (from mess==0.0.0) (1.7.0)\n",
- "Requirement already satisfied: pyarrow in /usr/local/lib/python3.10/dist-packages (from mess==0.0.0) (14.0.2)\n",
- "Requirement already satisfied: pyscf in /usr/local/lib/python3.10/dist-packages (from mess==0.0.0) (2.5.0)\n",
- "Requirement already satisfied: py3Dmol in /usr/local/lib/python3.10/dist-packages (from mess==0.0.0) (2.1.0)\n",
- "Requirement already satisfied: basis-set-exchange in /usr/local/lib/python3.10/dist-packages (from mess==0.0.0) (0.9.1)\n",
- "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from mess==0.0.0) (1.12)\n",
- "Requirement already satisfied: jsonschema in /usr/local/lib/python3.10/dist-packages (from basis-set-exchange->mess==0.0.0) (4.19.2)\n",
- "Requirement already satisfied: argcomplete in /usr/local/lib/python3.10/dist-packages (from basis-set-exchange->mess==0.0.0) (3.3.0)\n",
- "Requirement already satisfied: regex in /usr/local/lib/python3.10/dist-packages (from basis-set-exchange->mess==0.0.0) (2023.12.25)\n",
- "Requirement already satisfied: unidecode in /usr/local/lib/python3.10/dist-packages (from basis-set-exchange->mess==0.0.0) (1.3.8)\n",
- "Requirement already satisfied: typing-extensions>=4.5.0 in /usr/local/lib/python3.10/dist-packages (from equinox->mess==0.0.0) (4.11.0)\n",
- "Requirement already satisfied: numpy>=1.20.0 in /usr/local/lib/python3.10/dist-packages (from jaxtyping->mess==0.0.0) (1.25.2)\n",
- "Requirement already satisfied: typeguard==2.13.3 in /usr/local/lib/python3.10/dist-packages (from jaxtyping->mess==0.0.0) (2.13.3)\n",
- "Requirement already satisfied: ml-dtypes>=0.2.0 in /usr/local/lib/python3.10/dist-packages (from jax[cpu]->mess==0.0.0) (0.2.0)\n",
- "Requirement already satisfied: opt-einsum in /usr/local/lib/python3.10/dist-packages (from jax[cpu]->mess==0.0.0) (3.3.0)\n",
- "Requirement already satisfied: scipy>=1.9 in /usr/local/lib/python3.10/dist-packages (from jax[cpu]->mess==0.0.0) (1.11.4)\n",
- "Requirement already satisfied: jaxlib==0.4.26 in /usr/local/lib/python3.10/dist-packages (from jax[cpu]->mess==0.0.0) (0.4.26+cuda12.cudnn89)\n",
- "Requirement already satisfied: absl-py>=0.7.1 in /usr/local/lib/python3.10/dist-packages (from optax->mess==0.0.0) (1.4.0)\n",
- "Requirement already satisfied: chex>=0.1.86 in /usr/local/lib/python3.10/dist-packages (from optax->mess==0.0.0) (0.1.86)\n",
- "Requirement already satisfied: lineax>=0.0.4 in /usr/local/lib/python3.10/dist-packages (from optimistix->mess==0.0.0) (0.0.5)\n",
- "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas->mess==0.0.0) (2.8.2)\n",
- "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->mess==0.0.0) (2023.4)\n",
- "Requirement already satisfied: tzdata>=2022.1 in /usr/local/lib/python3.10/dist-packages (from pandas->mess==0.0.0) (2024.1)\n",
- "Requirement already satisfied: pyparsing in /usr/local/lib/python3.10/dist-packages (from periodictable->mess==0.0.0) (3.1.2)\n",
- "Requirement already satisfied: h5py>=2.7 in /usr/local/lib/python3.10/dist-packages (from pyscf->mess==0.0.0) (3.9.0)\n",
- "Requirement already satisfied: setuptools in /usr/local/lib/python3.10/dist-packages (from pyscf->mess==0.0.0) (67.7.2)\n",
- "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->mess==0.0.0) (1.3.0)\n",
- "Requirement already satisfied: toolz>=0.9.0 in /usr/local/lib/python3.10/dist-packages (from chex>=0.1.86->optax->mess==0.0.0) (0.12.1)\n",
- "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.2->pandas->mess==0.0.0) (1.16.0)\n",
- "Requirement already satisfied: attrs>=22.2.0 in /usr/local/lib/python3.10/dist-packages (from jsonschema->basis-set-exchange->mess==0.0.0) (23.2.0)\n",
- "Requirement already satisfied: jsonschema-specifications>=2023.03.6 in /usr/local/lib/python3.10/dist-packages (from jsonschema->basis-set-exchange->mess==0.0.0) (2023.12.1)\n",
- "Requirement already satisfied: referencing>=0.28.4 in /usr/local/lib/python3.10/dist-packages (from jsonschema->basis-set-exchange->mess==0.0.0) (0.34.0)\n",
- "Requirement already satisfied: rpds-py>=0.7.1 in /usr/local/lib/python3.10/dist-packages (from jsonschema->basis-set-exchange->mess==0.0.0) (0.18.0)\n"
- ]
- }
- ],
- "source": [
- "import sys\n",
- "\n",
- "if 'google.colab' in sys.modules:\n",
- " !pip install git+https://github.com/graphcore-research/mess.git\n",
- "\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 515
- },
- "id": "rb2pN_KOo7a5",
- "outputId": "6b7721d1-1947-4b63-a1b1-e36e7f358148"
- },
- "outputs": [
- {
- "data": {
- "application/3dmoljs_load.v0": "
\n
You appear to be running in JupyterLab (or JavaScript failed to load for some other reason). You need to install the 3dmol extension:
\n jupyter labextension install jupyterlab_3dmol
\n
\n",
- "text/html": [
- "\n",
- "
You appear to be running in JupyterLab (or JavaScript failed to load for some other reason). You need to install the 3dmol extension:
\n",
- " jupyter labextension install jupyterlab_3dmol
\n",
- "
\n",
- ""
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "text/plain": [
- "Structure(atomic_number=i64[3](numpy), position=f64[3,3](numpy))"
- ]
- },
- "execution_count": 2,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "from mess import molecule\n",
- "\n",
- "mol = molecule(\"water\")\n",
- "mol"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "72Lm2l0Jo7a6"
- },
- "source": [
- "MESS represents electrons using the Linear Combination of Atomic Orbitals ([LCAO](https://en.wikipedia.org/wiki/Linear_combination_of_atomic_orbitals)) method.\n",
- "We rely on the [Basis Set Exchange](https://www.basissetexchange.org/) project to\n",
- "provide access to the full range of previously calculated Gaussian Type Orbital parameters."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "KVb5kWdMo7a6",
- "outputId": "c25a95e1-853e-42dc-dda6-fa605cc7c0d6"
- },
- "outputs": [
- {
- "data": {
- "text/html": [
- "\n",
- "\n",
- "
\n",
- " \n",
- " \n",
- " | \n",
- " orbital | \n",
- " coefficient | \n",
- " norm | \n",
- " center | \n",
- " lmn | \n",
- " alpha | \n",
- "
\n",
- " \n",
- " primitive | \n",
- " | \n",
- " | \n",
- " | \n",
- " | \n",
- " | \n",
- " | \n",
- "
\n",
- " \n",
- " \n",
- " \n",
- " 0 | \n",
- " 0 | \n",
- " 0.001831 | \n",
- " 454.227134 | \n",
- " [0.0, 0.0, 0.2201531] | \n",
- " [0, 0, 0] | \n",
- " 5484.671875 | \n",
- "
\n",
- " \n",
- " 1 | \n",
- " 0 | \n",
- " 0.013950 | \n",
- " 109.734524 | \n",
- " [0.0, 0.0, 0.2201531] | \n",
- " [0, 0, 0] | \n",
- " 825.234924 | \n",
- "
\n",
- " \n",
- " 2 | \n",
- " 0 | \n",
- " 0.068445 | \n",
- " 36.191769 | \n",
- " [0.0, 0.0, 0.2201531] | \n",
- " [0, 0, 0] | \n",
- " 188.046951 | \n",
- "
\n",
- " \n",
- " 3 | \n",
- " 0 | \n",
- " 0.232714 | \n",
- " 13.992611 | \n",
- " [0.0, 0.0, 0.2201531] | \n",
- " [0, 0, 0] | \n",
- " 52.964500 | \n",
- "
\n",
- " \n",
- " 4 | \n",
- " 0 | \n",
- " 0.470193 | \n",
- " 5.939888 | \n",
- " [0.0, 0.0, 0.2201531] | \n",
- " [0, 0, 0] | \n",
- " 16.897570 | \n",
- "
\n",
- " \n",
- " 5 | \n",
- " 0 | \n",
- " 0.358521 | \n",
- " 2.663549 | \n",
- " [0.0, 0.0, 0.2201531] | \n",
- " [0, 0, 0] | \n",
- " 5.799635 | \n",
- "
\n",
- " \n",
- " 6 | \n",
- " 1 | \n",
- " -0.110778 | \n",
- " 5.578151 | \n",
- " [0.0, 0.0, 0.2201531] | \n",
- " [0, 0, 0] | \n",
- " 15.539617 | \n",
- "
\n",
- " \n",
- " 7 | \n",
- " 1 | \n",
- " -0.148026 | \n",
- " 1.862649 | \n",
- " [0.0, 0.0, 0.2201531] | \n",
- " [0, 0, 0] | \n",
- " 3.599934 | \n",
- "
\n",
- " \n",
- " 8 | \n",
- " 1 | \n",
- " 1.130767 | \n",
- " 0.720049 | \n",
- " [0.0, 0.0, 0.2201531] | \n",
- " [0, 0, 0] | \n",
- " 1.013762 | \n",
- "
\n",
- " \n",
- " 9 | \n",
- " 2 | \n",
- " 1.000000 | \n",
- " 0.266956 | \n",
- " [0.0, 0.0, 0.2201531] | \n",
- " [0, 0, 0] | \n",
- " 0.270006 | \n",
- "
\n",
- " \n",
- " 10 | \n",
- " 3 | \n",
- " 0.070874 | \n",
- " 43.978502 | \n",
- " [0.0, 0.0, 0.2201531] | \n",
- " [1, 0, 0] | \n",
- " 15.539617 | \n",
- "
\n",
- " \n",
- " 11 | \n",
- " 3 | \n",
- " 0.339753 | \n",
- " 7.068189 | \n",
- " [0.0, 0.0, 0.2201531] | \n",
- " [1, 0, 0] | \n",
- " 3.599934 | \n",
- "
\n",
- " \n",
- " 12 | \n",
- " 3 | \n",
- " 0.727159 | \n",
- " 1.449973 | \n",
- " [0.0, 0.0, 0.2201531] | \n",
- " [1, 0, 0] | \n",
- " 1.013762 | \n",
- "
\n",
- " \n",
- " 13 | \n",
- " 4 | \n",
- " 0.070874 | \n",
- " 43.978502 | \n",
- " [0.0, 0.0, 0.2201531] | \n",
- " [0, 1, 0] | \n",
- " 15.539617 | \n",
- "
\n",
- " \n",
- " 14 | \n",
- " 4 | \n",
- " 0.339753 | \n",
- " 7.068189 | \n",
- " [0.0, 0.0, 0.2201531] | \n",
- " [0, 1, 0] | \n",
- " 3.599934 | \n",
- "
\n",
- " \n",
- " 15 | \n",
- " 4 | \n",
- " 0.727159 | \n",
- " 1.449973 | \n",
- " [0.0, 0.0, 0.2201531] | \n",
- " [0, 1, 0] | \n",
- " 1.013762 | \n",
- "
\n",
- " \n",
- " 16 | \n",
- " 5 | \n",
- " 0.070874 | \n",
- " 43.978502 | \n",
- " [0.0, 0.0, 0.2201531] | \n",
- " [0, 0, 1] | \n",
- " 15.539617 | \n",
- "
\n",
- " \n",
- " 17 | \n",
- " 5 | \n",
- " 0.339753 | \n",
- " 7.068189 | \n",
- " [0.0, 0.0, 0.2201531] | \n",
- " [0, 0, 1] | \n",
- " 3.599934 | \n",
- "
\n",
- " \n",
- " 18 | \n",
- " 5 | \n",
- " 0.727159 | \n",
- " 1.449973 | \n",
- " [0.0, 0.0, 0.2201531] | \n",
- " [0, 0, 1] | \n",
- " 1.013762 | \n",
- "
\n",
- " \n",
- " 19 | \n",
- " 6 | \n",
- " 1.000000 | \n",
- " 0.277432 | \n",
- " [0.0, 0.0, 0.2201531] | \n",
- " [1, 0, 0] | \n",
- " 0.270006 | \n",
- "
\n",
- " \n",
- " 20 | \n",
- " 7 | \n",
- " 1.000000 | \n",
- " 0.277432 | \n",
- " [0.0, 0.0, 0.2201531] | \n",
- " [0, 1, 0] | \n",
- " 0.270006 | \n",
- "
\n",
- " \n",
- " 21 | \n",
- " 8 | \n",
- " 1.000000 | \n",
- " 0.277432 | \n",
- " [0.0, 0.0, 0.2201531] | \n",
- " [0, 0, 1] | \n",
- " 0.270006 | \n",
- "
\n",
- " \n",
- " 22 | \n",
- " 9 | \n",
- " 0.033495 | \n",
- " 6.417017 | \n",
- " [0.0, 1.4539553, -0.8808013] | \n",
- " [0, 0, 0] | \n",
- " 18.731136 | \n",
- "
\n",
- " \n",
- " 23 | \n",
- " 9 | \n",
- " 0.234727 | \n",
- " 1.553171 | \n",
- " [0.0, 1.4539553, -0.8808013] | \n",
- " [0, 0, 0] | \n",
- " 2.825394 | \n",
- "
\n",
- " \n",
- " 24 | \n",
- " 9 | \n",
- " 0.813757 | \n",
- " 0.510043 | \n",
- " [0.0, 1.4539553, -0.8808013] | \n",
- " [0, 0, 0] | \n",
- " 0.640122 | \n",
- "
\n",
- " \n",
- " 25 | \n",
- " 10 | \n",
- " 1.000000 | \n",
- " 0.181381 | \n",
- " [0.0, 1.4539553, -0.8808013] | \n",
- " [0, 0, 0] | \n",
- " 0.161278 | \n",
- "
\n",
- " \n",
- " 26 | \n",
- " 11 | \n",
- " 0.033495 | \n",
- " 6.417017 | \n",
- " [0.0, -1.4539553, -0.8808013] | \n",
- " [0, 0, 0] | \n",
- " 18.731136 | \n",
- "
\n",
- " \n",
- " 27 | \n",
- " 11 | \n",
- " 0.234727 | \n",
- " 1.553171 | \n",
- " [0.0, -1.4539553, -0.8808013] | \n",
- " [0, 0, 0] | \n",
- " 2.825394 | \n",
- "
\n",
- " \n",
- " 28 | \n",
- " 11 | \n",
- " 0.813757 | \n",
- " 0.510043 | \n",
- " [0.0, -1.4539553, -0.8808013] | \n",
- " [0, 0, 0] | \n",
- " 0.640122 | \n",
- "
\n",
- " \n",
- " 29 | \n",
- " 12 | \n",
- " 1.000000 | \n",
- " 0.181381 | \n",
- " [0.0, -1.4539553, -0.8808013] | \n",
- " [0, 0, 0] | \n",
- " 0.161278 | \n",
- "
\n",
- " \n",
- "
\n",
- "
"
- ],
- "text/plain": [
- " primitive orbital coefficient norm center lmn alpha\n",
- "----------- --------- ------------- ---------- ---------------------------------- ------- -----------\n",
- " 0 0 0.00183107 454.227 [0. 0. 0.2201531] [0 0 0] 5484.67\n",
- " 1 0 0.0139502 109.735 [0. 0. 0.2201531] [0 0 0] 825.235\n",
- " 2 0 0.0684451 36.1918 [0. 0. 0.2201531] [0 0 0] 188.047\n",
- " 3 0 0.232714 13.9926 [0. 0. 0.2201531] [0 0 0] 52.9645\n",
- " 4 0 0.470193 5.93989 [0. 0. 0.2201531] [0 0 0] 16.8976\n",
- " 5 0 0.358521 2.66355 [0. 0. 0.2201531] [0 0 0] 5.79964\n",
- " 6 1 -0.110778 5.57815 [0. 0. 0.2201531] [0 0 0] 15.5396\n",
- " 7 1 -0.148026 1.86265 [0. 0. 0.2201531] [0 0 0] 3.59993\n",
- " 8 1 1.13077 0.720049 [0. 0. 0.2201531] [0 0 0] 1.01376\n",
- " 9 2 1 0.266956 [0. 0. 0.2201531] [0 0 0] 0.270006\n",
- " 10 3 0.0708743 43.9785 [0. 0. 0.2201531] [1 0 0] 15.5396\n",
- " 11 3 0.339753 7.06819 [0. 0. 0.2201531] [1 0 0] 3.59993\n",
- " 12 3 0.727159 1.44997 [0. 0. 0.2201531] [1 0 0] 1.01376\n",
- " 13 4 0.0708743 43.9785 [0. 0. 0.2201531] [0 1 0] 15.5396\n",
- " 14 4 0.339753 7.06819 [0. 0. 0.2201531] [0 1 0] 3.59993\n",
- " 15 4 0.727159 1.44997 [0. 0. 0.2201531] [0 1 0] 1.01376\n",
- " 16 5 0.0708743 43.9785 [0. 0. 0.2201531] [0 0 1] 15.5396\n",
- " 17 5 0.339753 7.06819 [0. 0. 0.2201531] [0 0 1] 3.59993\n",
- " 18 5 0.727159 1.44997 [0. 0. 0.2201531] [0 0 1] 1.01376\n",
- " 19 6 1 0.277432 [0. 0. 0.2201531] [1 0 0] 0.270006\n",
- " 20 7 1 0.277432 [0. 0. 0.2201531] [0 1 0] 0.270006\n",
- " 21 8 1 0.277432 [0. 0. 0.2201531] [0 0 1] 0.270006\n",
- " 22 9 0.0334946 6.41702 [ 0. 1.4539553 -0.8808013] [0 0 0] 18.7311\n",
- " 23 9 0.234727 1.55317 [ 0. 1.4539553 -0.8808013] [0 0 0] 2.82539\n",
- " 24 9 0.813757 0.510043 [ 0. 1.4539553 -0.8808013] [0 0 0] 0.640122\n",
- " 25 10 1 0.181381 [ 0. 1.4539553 -0.8808013] [0 0 0] 0.161278\n",
- " 26 11 0.0334946 6.41702 [ 0. -1.4539553 -0.8808013] [0 0 0] 18.7311\n",
- " 27 11 0.234727 1.55317 [ 0. -1.4539553 -0.8808013] [0 0 0] 2.82539\n",
- " 28 11 0.813757 0.510043 [ 0. -1.4539553 -0.8808013] [0 0 0] 0.640122\n",
- " 29 12 1 0.181381 [ 0. -1.4539553 -0.8808013] [0 0 0] 0.161278"
- ]
- },
- "execution_count": 3,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "from mess import basisset\n",
- "\n",
- "basis_name = \"6-31g\"\n",
- "basis = basisset(mol, basis_name)\n",
- "basis"
- ]
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {
+ "id": "Hr9cSq_WZNb4",
+ "tags": [
+ "remove-input"
+ ]
+ },
+ "outputs": [],
+ "source": [
+ "# Copyright (c) 2024 Graphcore Ltd. All rights reserved."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "qR3UexoZZNb4"
+ },
+ "source": [
+ "# Brief Tour of MESS\n",
+ "\n",
+ "MESS is a modular toolkit for exploring the exciting interface between machine\n",
+ "learning, electronic structure, and algorithms.\n",
+ "\n",
+ "To begin our tour we build a single water molecule.\n",
+ "Each atom is represented an atomic number $Z_i$ and a position in Cartesian\n",
+ "coordinates $(x_i, y_i, z_i)$. In MESS we collect atoms into a `Structure` and we\n",
+ "provide a few examples built by the `molecule` function. \n",
+ "MESS is designed for interactive exploration so in a notebook environment a `Structure`\n",
+ "object will display a 3D visualisation\n",
+ "\n",
+ ":::{note}\n",
+ "The following code cell will install MESS into the Google Colab runtime.\n",
+ "Select the 🚀 in the toolbar above to try this out!\n",
+ ":::"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
},
+ "id": "NNyAVJ9Lo-PO",
+ "outputId": "e52fded0-11c4-4375-c546-9477e5bedb6c",
+ "tags": [
+ "hide-cell"
+ ]
+ },
+ "outputs": [
{
- "cell_type": "markdown",
- "metadata": {
- "id": "rlBloxuGo7a7"
- },
- "source": [
- "We now have all the pieces to run our first electronic structure simulation.\n",
- "This is done in two steps:\n",
- "* build a Hamiltonian by selecting a treatment for the quantum-mechanical exchange and correlation.\n",
- "* find the molecular orbital coefficients $C$ that minimise the energy of this Hamiltonian\n",
- "\n",
- "In the following we use the popular\n",
- "[PBE exchange-correlation functional](https://doi.org/10.1103/PhysRevLett.77.3865)\n",
- "of Density Functional Theory (DFT) to model the quantum-mechanical electron interactions.\n",
- "\n",
- "By default the minimisation routine will be compiled by JAX and executed on any available\n",
- "hardware accelerator (e.g. GPU/TPU)."
- ]
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Collecting git+https://github.com/graphcore-research/mess.git\n",
+ " Cloning https://github.com/graphcore-research/mess.git to /tmp/pip-req-build-7ks3ewjy\n",
+ " Running command git clone --filter=blob:none --quiet https://github.com/graphcore-research/mess.git /tmp/pip-req-build-7ks3ewjy\n",
+ " Resolved https://github.com/graphcore-research/mess.git to commit 2d51ea7d79a89387b5c7af76452c0abdba5dab64\n",
+ " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
+ " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
+ " Installing backend dependencies ... \u001b[?25l\u001b[?25hdone\n",
+ " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
+ "Collecting pyquante2@ git+https://github.com/rpmuller/pyquante2@pure (from mess==0.0.0)\n",
+ " Cloning https://github.com/rpmuller/pyquante2 (to revision pure) to /tmp/pip-install-ljts8vhb/pyquante2_fba2469426054776b2b72c9bcec5d3d8\n",
+ " Running command git clone --filter=blob:none --quiet https://github.com/rpmuller/pyquante2 /tmp/pip-install-ljts8vhb/pyquante2_fba2469426054776b2b72c9bcec5d3d8\n",
+ " Running command git checkout -b pure --track origin/pure\n",
+ " Switched to a new branch 'pure'\n",
+ " Branch 'pure' set up to track remote branch 'pure' from 'origin'.\n",
+ " Resolved https://github.com/rpmuller/pyquante2 to commit 822a1755c83f1730b1b063bc4ab2580a23342c02\n",
+ " Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
+ "Requirement already satisfied: equinox in /usr/local/lib/python3.10/dist-packages (from mess==0.0.0) (0.11.4)\n",
+ "Requirement already satisfied: jax[cpu] in /usr/local/lib/python3.10/dist-packages (from mess==0.0.0) (0.4.26)\n",
+ "Requirement already satisfied: jaxtyping in /usr/local/lib/python3.10/dist-packages (from mess==0.0.0) (0.2.28)\n",
+ "Requirement already satisfied: more-itertools in /usr/local/lib/python3.10/dist-packages (from mess==0.0.0) (10.1.0)\n",
+ "Requirement already satisfied: optax in /usr/local/lib/python3.10/dist-packages (from mess==0.0.0) (0.2.2)\n",
+ "Requirement already satisfied: optimistix in /usr/local/lib/python3.10/dist-packages (from mess==0.0.0) (0.0.6)\n",
+ "Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from mess==0.0.0) (2.0.3)\n",
+ "Requirement already satisfied: periodictable in /usr/local/lib/python3.10/dist-packages (from mess==0.0.0) (1.7.0)\n",
+ "Requirement already satisfied: pyarrow in /usr/local/lib/python3.10/dist-packages (from mess==0.0.0) (14.0.2)\n",
+ "Requirement already satisfied: pyscf in /usr/local/lib/python3.10/dist-packages (from mess==0.0.0) (2.5.0)\n",
+ "Requirement already satisfied: py3Dmol in /usr/local/lib/python3.10/dist-packages (from mess==0.0.0) (2.1.0)\n",
+ "Requirement already satisfied: basis-set-exchange in /usr/local/lib/python3.10/dist-packages (from mess==0.0.0) (0.9.1)\n",
+ "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from mess==0.0.0) (1.12)\n",
+ "Requirement already satisfied: jsonschema in /usr/local/lib/python3.10/dist-packages (from basis-set-exchange->mess==0.0.0) (4.19.2)\n",
+ "Requirement already satisfied: argcomplete in /usr/local/lib/python3.10/dist-packages (from basis-set-exchange->mess==0.0.0) (3.3.0)\n",
+ "Requirement already satisfied: regex in /usr/local/lib/python3.10/dist-packages (from basis-set-exchange->mess==0.0.0) (2023.12.25)\n",
+ "Requirement already satisfied: unidecode in /usr/local/lib/python3.10/dist-packages (from basis-set-exchange->mess==0.0.0) (1.3.8)\n",
+ "Requirement already satisfied: typing-extensions>=4.5.0 in /usr/local/lib/python3.10/dist-packages (from equinox->mess==0.0.0) (4.11.0)\n",
+ "Requirement already satisfied: numpy>=1.20.0 in /usr/local/lib/python3.10/dist-packages (from jaxtyping->mess==0.0.0) (1.25.2)\n",
+ "Requirement already satisfied: typeguard==2.13.3 in /usr/local/lib/python3.10/dist-packages (from jaxtyping->mess==0.0.0) (2.13.3)\n",
+ "Requirement already satisfied: ml-dtypes>=0.2.0 in /usr/local/lib/python3.10/dist-packages (from jax[cpu]->mess==0.0.0) (0.2.0)\n",
+ "Requirement already satisfied: opt-einsum in /usr/local/lib/python3.10/dist-packages (from jax[cpu]->mess==0.0.0) (3.3.0)\n",
+ "Requirement already satisfied: scipy>=1.9 in /usr/local/lib/python3.10/dist-packages (from jax[cpu]->mess==0.0.0) (1.11.4)\n",
+ "Requirement already satisfied: jaxlib==0.4.26 in /usr/local/lib/python3.10/dist-packages (from jax[cpu]->mess==0.0.0) (0.4.26+cuda12.cudnn89)\n",
+ "Requirement already satisfied: absl-py>=0.7.1 in /usr/local/lib/python3.10/dist-packages (from optax->mess==0.0.0) (1.4.0)\n",
+ "Requirement already satisfied: chex>=0.1.86 in /usr/local/lib/python3.10/dist-packages (from optax->mess==0.0.0) (0.1.86)\n",
+ "Requirement already satisfied: lineax>=0.0.4 in /usr/local/lib/python3.10/dist-packages (from optimistix->mess==0.0.0) (0.0.5)\n",
+ "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas->mess==0.0.0) (2.8.2)\n",
+ "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->mess==0.0.0) (2023.4)\n",
+ "Requirement already satisfied: tzdata>=2022.1 in /usr/local/lib/python3.10/dist-packages (from pandas->mess==0.0.0) (2024.1)\n",
+ "Requirement already satisfied: pyparsing in /usr/local/lib/python3.10/dist-packages (from periodictable->mess==0.0.0) (3.1.2)\n",
+ "Requirement already satisfied: h5py>=2.7 in /usr/local/lib/python3.10/dist-packages (from pyscf->mess==0.0.0) (3.9.0)\n",
+ "Requirement already satisfied: setuptools in /usr/local/lib/python3.10/dist-packages (from pyscf->mess==0.0.0) (67.7.2)\n",
+ "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->mess==0.0.0) (1.3.0)\n",
+ "Requirement already satisfied: toolz>=0.9.0 in /usr/local/lib/python3.10/dist-packages (from chex>=0.1.86->optax->mess==0.0.0) (0.12.1)\n",
+ "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.2->pandas->mess==0.0.0) (1.16.0)\n",
+ "Requirement already satisfied: attrs>=22.2.0 in /usr/local/lib/python3.10/dist-packages (from jsonschema->basis-set-exchange->mess==0.0.0) (23.2.0)\n",
+ "Requirement already satisfied: jsonschema-specifications>=2023.03.6 in /usr/local/lib/python3.10/dist-packages (from jsonschema->basis-set-exchange->mess==0.0.0) (2023.12.1)\n",
+ "Requirement already satisfied: referencing>=0.28.4 in /usr/local/lib/python3.10/dist-packages (from jsonschema->basis-set-exchange->mess==0.0.0) (0.35.0)\n",
+ "Requirement already satisfied: rpds-py>=0.7.1 in /usr/local/lib/python3.10/dist-packages (from jsonschema->basis-set-exchange->mess==0.0.0) (0.18.0)\n"
+ ]
+ }
+ ],
+ "source": [
+ "import sys\n",
+ "\n",
+ "if 'google.colab' in sys.modules:\n",
+ " !pip install git+https://github.com/graphcore-research/mess.git\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 514
},
+ "id": "rb2pN_KOo7a5",
+ "outputId": "34afe893-13ee-42c4-bd31-df9eabd4e6eb"
+ },
+ "outputs": [
{
- "cell_type": "code",
- "execution_count": 5,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "qaJNzYL2o7a7",
- "outputId": "1e58bd33-5a31-411d-b572-03c36f278038"
- },
- "outputs": [
- {
- "data": {
- "text/plain": [
- "-76.29898894943122"
- ]
- },
- "execution_count": 5,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "from mess import minimise, Hamiltonian\n",
- "\n",
- "H = Hamiltonian(basis, xc_method=\"pbe\")\n",
- "E, C, sol = minimise(H)\n",
- "float(E)"
+ "data": {
+ "application/3dmoljs_load.v0": "\n
3Dmol.js failed to load for some reason. Please check your browser console for error messages.
\n
\n",
+ "text/html": [
+ "\n",
+ "
3Dmol.js failed to load for some reason. Please check your browser console for error messages.
\n",
+ "
\n",
+ ""
]
+ },
+ "metadata": {},
+ "output_type": "display_data"
},
{
- "cell_type": "markdown",
- "metadata": {
- "id": "FXoOtEVIo7a7"
- },
- "source": [
- "We can visualise the electron density $\\rho(\\mathbf{r})$ from the solution we found.\n",
- "The electron cloud is approximately mickey mouse head shaped to use the technical term."
+ "data": {
+ "text/plain": [
+ "Structure(atomic_number=i64[3](numpy), position=f64[3,3](numpy))"
]
+ },
+ "execution_count": 3,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "from mess import molecule\n",
+ "\n",
+ "mol = molecule(\"water\")\n",
+ "mol"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "72Lm2l0Jo7a6"
+ },
+ "source": [
+ "MESS represents electrons using the Linear Combination of Atomic Orbitals ([LCAO](https://en.wikipedia.org/wiki/Linear_combination_of_atomic_orbitals)) method.\n",
+ "We rely on the [Basis Set Exchange](https://www.basissetexchange.org/) project to\n",
+ "provide access to the full range of previously calculated Gaussian Type Orbital parameters."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 1000
},
+ "id": "KVb5kWdMo7a6",
+ "outputId": "dcb29b7d-a5fc-4bd3-820b-fac3e389802d"
+ },
+ "outputs": [
{
- "cell_type": "code",
- "execution_count": 6,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 515
- },
- "id": "r0lgDmDio7a7",
- "outputId": "c9b93377-660d-4de7-d12f-d8e019aa0460"
- },
- "outputs": [
- {
- "data": {
- "application/3dmoljs_load.v0": "\n
3Dmol.js failed to load for some reason. Please check your browser console for error messages.
\n
\n",
- "text/html": [
- "\n",
- "
3Dmol.js failed to load for some reason. Please check your browser console for error messages.
\n",
- "
\n",
- ""
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "text/plain": [
- ""
- ]
- },
- "execution_count": 6,
- "metadata": {},
- "output_type": "execute_result"
- }
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " orbital | \n",
+ " coefficient | \n",
+ " norm | \n",
+ " center | \n",
+ " lmn | \n",
+ " alpha | \n",
+ "
\n",
+ " \n",
+ " primitive | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 0.001831 | \n",
+ " 454.227134 | \n",
+ " [0.0, 0.0, 0.2201531] | \n",
+ " [0, 0, 0] | \n",
+ " 5484.671875 | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " 0 | \n",
+ " 0.013950 | \n",
+ " 109.734524 | \n",
+ " [0.0, 0.0, 0.2201531] | \n",
+ " [0, 0, 0] | \n",
+ " 825.234924 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 0 | \n",
+ " 0.068445 | \n",
+ " 36.191769 | \n",
+ " [0.0, 0.0, 0.2201531] | \n",
+ " [0, 0, 0] | \n",
+ " 188.046951 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 0 | \n",
+ " 0.232714 | \n",
+ " 13.992611 | \n",
+ " [0.0, 0.0, 0.2201531] | \n",
+ " [0, 0, 0] | \n",
+ " 52.964500 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 0 | \n",
+ " 0.470193 | \n",
+ " 5.939888 | \n",
+ " [0.0, 0.0, 0.2201531] | \n",
+ " [0, 0, 0] | \n",
+ " 16.897570 | \n",
+ "
\n",
+ " \n",
+ " 5 | \n",
+ " 0 | \n",
+ " 0.358521 | \n",
+ " 2.663549 | \n",
+ " [0.0, 0.0, 0.2201531] | \n",
+ " [0, 0, 0] | \n",
+ " 5.799635 | \n",
+ "
\n",
+ " \n",
+ " 6 | \n",
+ " 1 | \n",
+ " -0.110778 | \n",
+ " 5.578151 | \n",
+ " [0.0, 0.0, 0.2201531] | \n",
+ " [0, 0, 0] | \n",
+ " 15.539617 | \n",
+ "
\n",
+ " \n",
+ " 7 | \n",
+ " 1 | \n",
+ " -0.148026 | \n",
+ " 1.862649 | \n",
+ " [0.0, 0.0, 0.2201531] | \n",
+ " [0, 0, 0] | \n",
+ " 3.599934 | \n",
+ "
\n",
+ " \n",
+ " 8 | \n",
+ " 1 | \n",
+ " 1.130767 | \n",
+ " 0.720049 | \n",
+ " [0.0, 0.0, 0.2201531] | \n",
+ " [0, 0, 0] | \n",
+ " 1.013762 | \n",
+ "
\n",
+ " \n",
+ " 9 | \n",
+ " 2 | \n",
+ " 1.000000 | \n",
+ " 0.266956 | \n",
+ " [0.0, 0.0, 0.2201531] | \n",
+ " [0, 0, 0] | \n",
+ " 0.270006 | \n",
+ "
\n",
+ " \n",
+ " 10 | \n",
+ " 3 | \n",
+ " 0.070874 | \n",
+ " 43.978502 | \n",
+ " [0.0, 0.0, 0.2201531] | \n",
+ " [1, 0, 0] | \n",
+ " 15.539617 | \n",
+ "
\n",
+ " \n",
+ " 11 | \n",
+ " 3 | \n",
+ " 0.339753 | \n",
+ " 7.068189 | \n",
+ " [0.0, 0.0, 0.2201531] | \n",
+ " [1, 0, 0] | \n",
+ " 3.599934 | \n",
+ "
\n",
+ " \n",
+ " 12 | \n",
+ " 3 | \n",
+ " 0.727159 | \n",
+ " 1.449973 | \n",
+ " [0.0, 0.0, 0.2201531] | \n",
+ " [1, 0, 0] | \n",
+ " 1.013762 | \n",
+ "
\n",
+ " \n",
+ " 13 | \n",
+ " 4 | \n",
+ " 0.070874 | \n",
+ " 43.978502 | \n",
+ " [0.0, 0.0, 0.2201531] | \n",
+ " [0, 1, 0] | \n",
+ " 15.539617 | \n",
+ "
\n",
+ " \n",
+ " 14 | \n",
+ " 4 | \n",
+ " 0.339753 | \n",
+ " 7.068189 | \n",
+ " [0.0, 0.0, 0.2201531] | \n",
+ " [0, 1, 0] | \n",
+ " 3.599934 | \n",
+ "
\n",
+ " \n",
+ " 15 | \n",
+ " 4 | \n",
+ " 0.727159 | \n",
+ " 1.449973 | \n",
+ " [0.0, 0.0, 0.2201531] | \n",
+ " [0, 1, 0] | \n",
+ " 1.013762 | \n",
+ "
\n",
+ " \n",
+ " 16 | \n",
+ " 5 | \n",
+ " 0.070874 | \n",
+ " 43.978502 | \n",
+ " [0.0, 0.0, 0.2201531] | \n",
+ " [0, 0, 1] | \n",
+ " 15.539617 | \n",
+ "
\n",
+ " \n",
+ " 17 | \n",
+ " 5 | \n",
+ " 0.339753 | \n",
+ " 7.068189 | \n",
+ " [0.0, 0.0, 0.2201531] | \n",
+ " [0, 0, 1] | \n",
+ " 3.599934 | \n",
+ "
\n",
+ " \n",
+ " 18 | \n",
+ " 5 | \n",
+ " 0.727159 | \n",
+ " 1.449973 | \n",
+ " [0.0, 0.0, 0.2201531] | \n",
+ " [0, 0, 1] | \n",
+ " 1.013762 | \n",
+ "
\n",
+ " \n",
+ " 19 | \n",
+ " 6 | \n",
+ " 1.000000 | \n",
+ " 0.277432 | \n",
+ " [0.0, 0.0, 0.2201531] | \n",
+ " [1, 0, 0] | \n",
+ " 0.270006 | \n",
+ "
\n",
+ " \n",
+ " 20 | \n",
+ " 7 | \n",
+ " 1.000000 | \n",
+ " 0.277432 | \n",
+ " [0.0, 0.0, 0.2201531] | \n",
+ " [0, 1, 0] | \n",
+ " 0.270006 | \n",
+ "
\n",
+ " \n",
+ " 21 | \n",
+ " 8 | \n",
+ " 1.000000 | \n",
+ " 0.277432 | \n",
+ " [0.0, 0.0, 0.2201531] | \n",
+ " [0, 0, 1] | \n",
+ " 0.270006 | \n",
+ "
\n",
+ " \n",
+ " 22 | \n",
+ " 9 | \n",
+ " 0.033495 | \n",
+ " 6.417017 | \n",
+ " [0.0, 1.4539553, -0.8808013] | \n",
+ " [0, 0, 0] | \n",
+ " 18.731136 | \n",
+ "
\n",
+ " \n",
+ " 23 | \n",
+ " 9 | \n",
+ " 0.234727 | \n",
+ " 1.553171 | \n",
+ " [0.0, 1.4539553, -0.8808013] | \n",
+ " [0, 0, 0] | \n",
+ " 2.825394 | \n",
+ "
\n",
+ " \n",
+ " 24 | \n",
+ " 9 | \n",
+ " 0.813757 | \n",
+ " 0.510043 | \n",
+ " [0.0, 1.4539553, -0.8808013] | \n",
+ " [0, 0, 0] | \n",
+ " 0.640122 | \n",
+ "
\n",
+ " \n",
+ " 25 | \n",
+ " 10 | \n",
+ " 1.000000 | \n",
+ " 0.181381 | \n",
+ " [0.0, 1.4539553, -0.8808013] | \n",
+ " [0, 0, 0] | \n",
+ " 0.161278 | \n",
+ "
\n",
+ " \n",
+ " 26 | \n",
+ " 11 | \n",
+ " 0.033495 | \n",
+ " 6.417017 | \n",
+ " [0.0, -1.4539553, -0.8808013] | \n",
+ " [0, 0, 0] | \n",
+ " 18.731136 | \n",
+ "
\n",
+ " \n",
+ " 27 | \n",
+ " 11 | \n",
+ " 0.234727 | \n",
+ " 1.553171 | \n",
+ " [0.0, -1.4539553, -0.8808013] | \n",
+ " [0, 0, 0] | \n",
+ " 2.825394 | \n",
+ "
\n",
+ " \n",
+ " 28 | \n",
+ " 11 | \n",
+ " 0.813757 | \n",
+ " 0.510043 | \n",
+ " [0.0, -1.4539553, -0.8808013] | \n",
+ " [0, 0, 0] | \n",
+ " 0.640122 | \n",
+ "
\n",
+ " \n",
+ " 29 | \n",
+ " 12 | \n",
+ " 1.000000 | \n",
+ " 0.181381 | \n",
+ " [0.0, -1.4539553, -0.8808013] | \n",
+ " [0, 0, 0] | \n",
+ " 0.161278 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
],
- "source": [
- "import py3Dmol\n",
- "from mess.mesh import density, uniform_mesh\n",
- "from mess.plot import plot_volume, plot_molecule\n",
- "\n",
- "view = py3Dmol.view()\n",
- "plot_molecule(view, mol)\n",
- "mesh = uniform_mesh()\n",
- "rho = density(basis, mesh, basis.density_matrix(C))\n",
- "plot_volume(view, rho, mesh.axes)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "1ipZwc3go7a7"
- },
- "source": [
- "Performance is one of several goals of the MESS project so lets measure how long the\n",
- "energy minimisation takes"
+ "text/plain": [
+ " primitive orbital coefficient norm center lmn alpha\n",
+ "----------- --------- ------------- ---------- ---------------------------------- ------- -----------\n",
+ " 0 0 0.00183107 454.227 [0. 0. 0.2201531] [0 0 0] 5484.67\n",
+ " 1 0 0.0139502 109.735 [0. 0. 0.2201531] [0 0 0] 825.235\n",
+ " 2 0 0.0684451 36.1918 [0. 0. 0.2201531] [0 0 0] 188.047\n",
+ " 3 0 0.232714 13.9926 [0. 0. 0.2201531] [0 0 0] 52.9645\n",
+ " 4 0 0.470193 5.93989 [0. 0. 0.2201531] [0 0 0] 16.8976\n",
+ " 5 0 0.358521 2.66355 [0. 0. 0.2201531] [0 0 0] 5.79964\n",
+ " 6 1 -0.110778 5.57815 [0. 0. 0.2201531] [0 0 0] 15.5396\n",
+ " 7 1 -0.148026 1.86265 [0. 0. 0.2201531] [0 0 0] 3.59993\n",
+ " 8 1 1.13077 0.720049 [0. 0. 0.2201531] [0 0 0] 1.01376\n",
+ " 9 2 1 0.266956 [0. 0. 0.2201531] [0 0 0] 0.270006\n",
+ " 10 3 0.0708743 43.9785 [0. 0. 0.2201531] [1 0 0] 15.5396\n",
+ " 11 3 0.339753 7.06819 [0. 0. 0.2201531] [1 0 0] 3.59993\n",
+ " 12 3 0.727159 1.44997 [0. 0. 0.2201531] [1 0 0] 1.01376\n",
+ " 13 4 0.0708743 43.9785 [0. 0. 0.2201531] [0 1 0] 15.5396\n",
+ " 14 4 0.339753 7.06819 [0. 0. 0.2201531] [0 1 0] 3.59993\n",
+ " 15 4 0.727159 1.44997 [0. 0. 0.2201531] [0 1 0] 1.01376\n",
+ " 16 5 0.0708743 43.9785 [0. 0. 0.2201531] [0 0 1] 15.5396\n",
+ " 17 5 0.339753 7.06819 [0. 0. 0.2201531] [0 0 1] 3.59993\n",
+ " 18 5 0.727159 1.44997 [0. 0. 0.2201531] [0 0 1] 1.01376\n",
+ " 19 6 1 0.277432 [0. 0. 0.2201531] [1 0 0] 0.270006\n",
+ " 20 7 1 0.277432 [0. 0. 0.2201531] [0 1 0] 0.270006\n",
+ " 21 8 1 0.277432 [0. 0. 0.2201531] [0 0 1] 0.270006\n",
+ " 22 9 0.0334946 6.41702 [ 0. 1.4539553 -0.8808013] [0 0 0] 18.7311\n",
+ " 23 9 0.234727 1.55317 [ 0. 1.4539553 -0.8808013] [0 0 0] 2.82539\n",
+ " 24 9 0.813757 0.510043 [ 0. 1.4539553 -0.8808013] [0 0 0] 0.640122\n",
+ " 25 10 1 0.181381 [ 0. 1.4539553 -0.8808013] [0 0 0] 0.161278\n",
+ " 26 11 0.0334946 6.41702 [ 0. -1.4539553 -0.8808013] [0 0 0] 18.7311\n",
+ " 27 11 0.234727 1.55317 [ 0. -1.4539553 -0.8808013] [0 0 0] 2.82539\n",
+ " 28 11 0.813757 0.510043 [ 0. -1.4539553 -0.8808013] [0 0 0] 0.640122\n",
+ " 29 12 1 0.181381 [ 0. -1.4539553 -0.8808013] [0 0 0] 0.161278"
]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "from mess import basisset\n",
+ "\n",
+ "basis_name = \"6-31g\"\n",
+ "basis = basisset(mol, basis_name)\n",
+ "basis"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "rlBloxuGo7a7"
+ },
+ "source": [
+ "We now have all the pieces to run our first electronic structure simulation.\n",
+ "This is done in two steps:\n",
+ "* build a Hamiltonian by selecting a treatment for the quantum-mechanical exchange and correlation.\n",
+ "* find the molecular orbital coefficients $C$ that minimise the energy of this Hamiltonian\n",
+ "\n",
+ "In the following we use the popular\n",
+ "[PBE exchange-correlation functional](https://doi.org/10.1103/PhysRevLett.77.3865)\n",
+ "of Density Functional Theory (DFT) to model the quantum-mechanical electron interactions.\n",
+ "\n",
+ "By default the minimisation routine will be compiled by JAX and executed on any available\n",
+ "hardware accelerator (e.g. GPU/TPU)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
},
+ "id": "qaJNzYL2o7a7",
+ "outputId": "60392863-5809-43d9-ef33-4b42a70fc749"
+ },
+ "outputs": [
{
- "cell_type": "code",
- "execution_count": 7,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "SyXidjB-o7a8",
- "outputId": "38f21a87-4948-4dcf-fceb-894029f6ccca"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "139 ms ± 395 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
- ]
- }
- ],
- "source": [
- "%%timeit\n",
- "E, C, _ = minimise(H)\n",
- "E, C = E.block_until_ready(), C.block_until_ready()"
+ "data": {
+ "text/plain": [
+ "-76.29898732997808"
]
+ },
+ "execution_count": 5,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "from mess import minimise, Hamiltonian\n",
+ "\n",
+ "H = Hamiltonian(basis, xc_method=\"pbe\")\n",
+ "E, C, sol = minimise(H)\n",
+ "float(E)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "FXoOtEVIo7a7"
+ },
+ "source": [
+ "We can visualise the electron density $\\rho(\\mathbf{r})$ from the solution we found.\n",
+ "The electron cloud is approximately mickey mouse head shaped to use the technical term."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 514
},
+ "id": "r0lgDmDio7a7",
+ "outputId": "281529f9-e930-47c7-b578-21bbb8898d7f"
+ },
+ "outputs": [
{
- "cell_type": "markdown",
- "metadata": {
- "id": "wzSb3QLAo7a8"
- },
- "source": [
- "Before we get carried away profiling, we first make sure that the MESS simulation agrees with a standard\n",
- "and well used DFT software package - [PySCF](https://pyscf.org/)"
+ "data": {
+ "application/3dmoljs_load.v0": "\n
3Dmol.js failed to load for some reason. Please check your browser console for error messages.
\n
\n",
+ "text/html": [
+ "\n",
+ "
3Dmol.js failed to load for some reason. Please check your browser console for error messages.
\n",
+ "
\n",
+ ""
]
+ },
+ "metadata": {},
+ "output_type": "display_data"
},
{
- "cell_type": "code",
- "execution_count": 8,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "cpLKYsZVo7a8",
- "outputId": "9a2d5575-1c74-4695-e7e3-0dad95d2b121"
- },
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/usr/local/lib/python3.10/dist-packages/pyscf/gto/mole.py:1286: UserWarning: Function mol.dumps drops attribute spin because it is not JSON-serializable\n",
- " warnings.warn(msg)\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "converged SCF energy = -76.2989811596524\n"
- ]
- },
- {
- "data": {
- "text/plain": [
- "-76.29898115965236"
- ]
- },
- "execution_count": 8,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "from mess.interop import to_pyscf\n",
- "from pyscf import dft, scf\n",
- "\n",
- "\n",
- "scf_mol = to_pyscf(mol, basis_name)\n",
- "s = dft.RKS(scf_mol, xc=\"pbe\")\n",
- "s.kernel()"
+ "data": {
+ "text/plain": [
+ ""
]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "import py3Dmol\n",
+ "from mess.mesh import density, uniform_mesh\n",
+ "from mess.plot import plot_volume, plot_molecule\n",
+ "\n",
+ "view = py3Dmol.view()\n",
+ "plot_molecule(view, mol)\n",
+ "mesh = uniform_mesh()\n",
+ "rho = density(basis, mesh, basis.density_matrix(C))\n",
+ "plot_volume(view, rho, mesh.axes)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "1ipZwc3go7a7"
+ },
+ "source": [
+ "Performance is one of several goals of the MESS project so lets measure how long the\n",
+ "energy minimisation takes"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
},
+ "id": "kU8Xr34Ya9Dx",
+ "outputId": "eb67c335-d140-4911-f94c-61b992843d33"
+ },
+ "outputs": [
{
- "cell_type": "markdown",
- "metadata": {
- "id": "AK42Yofxo7a8"
- },
- "source": [
- "The calculated energies match!...lets open a can of worms and measure the performance of\n",
- " the PySCF energy minimisation"
- ]
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "134 ms ± 711 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
+ ]
+ }
+ ],
+ "source": [
+ "def mess_benchmark():\n",
+ " E, C, _ = minimise(H)\n",
+ " E, C = E.block_until_ready(), C.block_until_ready()\n",
+ "\n",
+ "\n",
+ "mess_time = %timeit -o mess_benchmark()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "wzSb3QLAo7a8"
+ },
+ "source": [
+ "Before we get carried away profiling, we first make sure that the MESS simulation agrees with a standard\n",
+ "and well used DFT software package - [PySCF](https://pyscf.org/)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
},
+ "id": "cpLKYsZVo7a8",
+ "outputId": "2a81e55e-c410-4108-a6ea-0c567aa83564"
+ },
+ "outputs": [
{
- "cell_type": "code",
- "execution_count": 9,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "jqItCM_to7a8",
- "outputId": "4cebccdc-6245-49a5-b173-301089b7f449"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "converged SCF energy = -76.2989811596524\n",
- "converged SCF energy = -76.2989811596524\n",
- "converged SCF energy = -76.2989811596524\n",
- "converged SCF energy = -76.2989811596524\n",
- "converged SCF energy = -76.2989811596524\n",
- "converged SCF energy = -76.2989811596523\n",
- "converged SCF energy = -76.2989811596524\n",
- "converged SCF energy = -76.2989811596525\n",
- "2.08 s ± 180 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
- ]
- }
- ],
- "source": [
- "%%timeit\n",
- "s.kernel()"
- ]
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/usr/local/lib/python3.10/dist-packages/pyscf/gto/mole.py:1286: UserWarning: Function mol.dumps drops attribute spin because it is not JSON-serializable\n",
+ " warnings.warn(msg)\n"
+ ]
},
{
- "cell_type": "markdown",
- "metadata": {
- "id": "cT5PXL_To7a8"
- },
- "source": [
- "We've measured a nearly 15X speedup...There are several gotchas of course that will be\n",
- "explored in due time...stay tuned for more!"
- ]
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "converged SCF energy = -76.2989811596524\n"
+ ]
},
{
- "cell_type": "markdown",
- "metadata": {
- "id": "NBakCcCOo7a9"
- },
- "source": [
- "## Hartree-Fock\n",
- "\n",
- "Hartree-Fock is a closely related method to the DFT solution found above that in MESS\n",
- "is selected by passing `xc_method=hfx`"
+ "data": {
+ "text/plain": [
+ "True"
]
+ },
+ "execution_count": 8,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "import numpy as np\n",
+ "from mess.interop import to_pyscf\n",
+ "from pyscf import dft, scf\n",
+ "\n",
+ "\n",
+ "scf_mol = to_pyscf(mol, basis_name)\n",
+ "s = dft.RKS(scf_mol, xc=\"pbe\")\n",
+ "s.kernel()\n",
+ "np.allclose(s.energy_tot(), E)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "AK42Yofxo7a8"
+ },
+ "source": [
+ "The calculated energies match!...lets open a can of worms and measure the performance of\n",
+ " the PySCF energy minimisation"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
},
+ "id": "jqItCM_to7a8",
+ "outputId": "5376dc0e-cd9b-4dac-cfe3-ad99d95455c3"
+ },
+ "outputs": [
{
- "cell_type": "code",
- "execution_count": 10,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "O8y-XuFMo7a9",
- "outputId": "03de82c5-83e8-4f97-e41b-5287657f6505"
- },
- "outputs": [
- {
- "data": {
- "text/plain": [
- "-75.98417353837593"
- ]
- },
- "execution_count": 10,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "H = Hamiltonian(basis, xc_method=\"hfx\")\n",
- "E, C, sol = minimise(H)\n",
- "float(E)"
- ]
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "converged SCF energy = -76.2989811596524\n",
+ "converged SCF energy = -76.2989811596524\n",
+ "converged SCF energy = -76.2989811596524\n",
+ "converged SCF energy = -76.2989811596523\n",
+ "converged SCF energy = -76.2989811596524\n",
+ "converged SCF energy = -76.2989811596523\n",
+ "converged SCF energy = -76.2989811596524\n",
+ "converged SCF energy = -76.2989811596524\n",
+ "2.12 s ± 133 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
+ ]
+ }
+ ],
+ "source": [
+ "pyscf_time = %timeit -o s.kernel()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
},
+ "id": "1eScOBghbwsm",
+ "outputId": "cd787303-686c-4ca9-dc3e-7597495717a4"
+ },
+ "outputs": [
{
- "cell_type": "markdown",
- "metadata": {
- "id": "bI7L6g0Do7a9"
- },
- "source": [
- "The energy is a little higher than the DFT solution found earlier but at a significantly\n",
- "reduced computational cost"
+ "data": {
+ "text/plain": [
+ "15.896320044865424"
]
+ },
+ "execution_count": 10,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "pyscf_time.average / mess_time.average"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "cT5PXL_To7a8"
+ },
+ "source": [
+ "We've measured a nearly 16X speedup...There are several gotchas of course that will be\n",
+ "explored in due time...stay tuned for more!"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "NBakCcCOo7a9"
+ },
+ "source": [
+ "## Hartree-Fock\n",
+ "\n",
+ "Hartree-Fock is a closely related method to the DFT solution found above that in MESS\n",
+ "is selected by passing `xc_method=hfx`"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
},
+ "id": "O8y-XuFMo7a9",
+ "outputId": "efb10e00-3be6-4e87-fb4e-b1b9deb2d3b9"
+ },
+ "outputs": [
{
- "cell_type": "code",
- "execution_count": 11,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "HS12lv2Ao7a9",
- "outputId": "e0bf34ad-8fdd-41c9-f442-72e317619fb2"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "20.6 ms ± 51.8 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
- ]
- }
- ],
- "source": [
- "%%timeit\n",
- "minimise(H)"
+ "data": {
+ "text/plain": [
+ "-75.9841721516878"
]
+ },
+ "execution_count": 11,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "H = Hamiltonian(basis, xc_method=\"hfx\")\n",
+ "E, C, sol = minimise(H)\n",
+ "float(E)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "bI7L6g0Do7a9"
+ },
+ "source": [
+ "The energy is a little higher than the DFT solution found earlier but at a significantly\n",
+ "reduced computational cost"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
},
+ "id": "HS12lv2Ao7a9",
+ "outputId": "dae30ab4-26e4-4855-d17b-bf89e99bc9df"
+ },
+ "outputs": [
{
- "cell_type": "markdown",
- "metadata": {
- "id": "PzrtmXnHo7a9"
- },
- "source": [
- "As another sanity check we make sure the calculated Hartree-Fock energy calculated by MESS agrees with PySCF"
- ]
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "21.6 ms ± 358 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
+ ]
+ }
+ ],
+ "source": [
+ "%%timeit\n",
+ "minimise(H)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "PzrtmXnHo7a9"
+ },
+ "source": [
+ "As another sanity check we make sure the calculated Hartree-Fock energy calculated by MESS agrees with PySCF"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
},
+ "id": "SV9pzl__o7a9",
+ "outputId": "9bbc38e7-2d65-4aca-c8e3-aaf426c7f847"
+ },
+ "outputs": [
{
- "cell_type": "code",
- "execution_count": 12,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "SV9pzl__o7a9",
- "outputId": "c6f4047a-24a3-47eb-e463-ace1b5df4852"
- },
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/usr/local/lib/python3.10/dist-packages/pyscf/gto/mole.py:1286: UserWarning: Function mol.dumps drops attribute spin because it is not JSON-serializable\n",
- " warnings.warn(msg)\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "converged SCF energy = -75.9841721516934\n"
- ]
- },
- {
- "data": {
- "text/plain": [
- "-75.98417215169343"
- ]
- },
- "execution_count": 12,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "s = scf.RHF(scf_mol)\n",
- "s.kernel()"
- ]
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/usr/local/lib/python3.10/dist-packages/pyscf/gto/mole.py:1286: UserWarning: Function mol.dumps drops attribute spin because it is not JSON-serializable\n",
+ " warnings.warn(msg)\n"
+ ]
},
{
- "cell_type": "markdown",
- "metadata": {
- "id": "gozX_FLEo7a9"
- },
- "source": [
- "We hope you enjoyed your tour of MESS and welcome any feedback as\n",
- "[github issues](https://github.com/graphcore-research/mess/issues) where we can continue the discussion.\n",
- "\n",
- "\n",
- ":::{note}\n",
- "For reproducibility we record the default accelerator (if any) used by JAX and the CPU architecture used to execute this notebook.\n",
- ":::"
- ]
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "converged SCF energy = -75.9841721516933\n"
+ ]
},
{
- "cell_type": "code",
- "execution_count": 13,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 35
- },
- "id": "Vl7iP8ojo7a-",
- "outputId": "2e8b81b7-309e-405f-c91d-3a7a7a67d713"
- },
- "outputs": [
- {
- "data": {
- "application/vnd.google.colaboratory.intrinsic+json": {
- "type": "string"
- },
- "text/plain": [
- "'NVIDIA A100-SXM4-40GB'"
- ]
- },
- "execution_count": 13,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "import jax\n",
- "\n",
- "jax.devices()[0].device_kind"
+ "data": {
+ "text/plain": [
+ "-75.98417215169334"
]
+ },
+ "execution_count": 13,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "s = scf.RHF(scf_mol)\n",
+ "s.kernel()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "gozX_FLEo7a9"
+ },
+ "source": [
+ "We hope you enjoyed your tour of MESS and welcome any feedback as\n",
+ "[github issues](https://github.com/graphcore-research/mess/issues) where we can continue the discussion.\n",
+ "\n",
+ "\n",
+ ":::{note}\n",
+ "For reproducibility we record the default accelerator (if any) used by JAX and the CPU architecture used to execute this notebook.\n",
+ ":::"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 35
},
+ "id": "Vl7iP8ojo7a-",
+ "outputId": "e6342599-0e5a-4d22-a54e-6fff9a61d0c6"
+ },
+ "outputs": [
{
- "cell_type": "code",
- "execution_count": 14,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "nRA2mRW5o7a-",
- "outputId": "11b6ad81-3b95-428a-e5e0-69297a79bc17"
+ "data": {
+ "application/vnd.google.colaboratory.intrinsic+json": {
+ "type": "string"
},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Architecture: x86_64\n",
- " CPU op-mode(s): 32-bit, 64-bit\n",
- " Address sizes: 46 bits physical, 48 bits virtual\n",
- " Byte Order: Little Endian\n",
- "CPU(s): 12\n",
- " On-line CPU(s) list: 0-11\n",
- "Vendor ID: GenuineIntel\n",
- " Model name: Intel(R) Xeon(R) CPU @ 2.20GHz\n",
- " CPU family: 6\n",
- " Model: 85\n",
- " Thread(s) per core: 2\n",
- " Core(s) per socket: 6\n",
- " Socket(s): 1\n",
- " Stepping: 7\n",
- " BogoMIPS: 4400.44\n",
- " Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clf\n",
- " lush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_\n",
- " good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fm\n",
- " a cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hyp\n",
- " ervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp ibrs_\n",
- " enhanced fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm mpx a\n",
- " vx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl \n",
- " xsaveopt xsavec xgetbv1 xsaves arat avx512_vnni md_clear arch_capabilities\n",
- "Virtualization features: \n",
- " Hypervisor vendor: KVM\n",
- " Virtualization type: full\n",
- "Caches (sum of all): \n",
- " L1d: 192 KiB (6 instances)\n",
- " L1i: 192 KiB (6 instances)\n",
- " L2: 6 MiB (6 instances)\n",
- " L3: 38.5 MiB (1 instance)\n",
- "NUMA: \n",
- " NUMA node(s): 1\n",
- " NUMA node0 CPU(s): 0-11\n",
- "Vulnerabilities: \n",
- " Gather data sampling: Not affected\n",
- " Itlb multihit: Not affected\n",
- " L1tf: Not affected\n",
- " Mds: Vulnerable; SMT Host state unknown\n",
- " Meltdown: Not affected\n",
- " Mmio stale data: Vulnerable\n",
- " Retbleed: Vulnerable\n",
- " Spec rstack overflow: Not affected\n",
- " Spec store bypass: Vulnerable\n",
- " Spectre v1: Vulnerable: __user pointer sanitization and usercopy barriers only; no swap\n",
- " gs barriers\n",
- " Spectre v2: Vulnerable, IBPB: disabled, STIBP: disabled, PBRSB-eIBRS: Vulnerable\n",
- " Srbds: Not affected\n",
- " Tsx async abort: Vulnerable\n"
- ]
- }
- ],
- "source": [
- "!lscpu"
+ "text/plain": [
+ "'NVIDIA A100-SXM4-40GB'"
]
+ },
+ "execution_count": 14,
+ "metadata": {},
+ "output_type": "execute_result"
}
- ],
- "metadata": {
- "accelerator": "GPU",
+ ],
+ "source": [
+ "import jax\n",
+ "\n",
+ "jax.devices()[0].device_kind"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {
"colab": {
- "gpuType": "A100",
- "machine_shape": "hm",
- "provenance": []
- },
- "kernelspec": {
- "display_name": "Python 3",
- "name": "python3"
+ "base_uri": "https://localhost:8080/"
},
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.10.13"
+ "id": "nRA2mRW5o7a-",
+ "outputId": "20370952-618f-4bff-c387-a4e77f584c50"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Architecture: x86_64\n",
+ " CPU op-mode(s): 32-bit, 64-bit\n",
+ " Address sizes: 46 bits physical, 48 bits virtual\n",
+ " Byte Order: Little Endian\n",
+ "CPU(s): 12\n",
+ " On-line CPU(s) list: 0-11\n",
+ "Vendor ID: GenuineIntel\n",
+ " Model name: Intel(R) Xeon(R) CPU @ 2.20GHz\n",
+ " CPU family: 6\n",
+ " Model: 85\n",
+ " Thread(s) per core: 2\n",
+ " Core(s) per socket: 6\n",
+ " Socket(s): 1\n",
+ " Stepping: 7\n",
+ " BogoMIPS: 4400.39\n",
+ " Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clf\n",
+ " lush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_\n",
+ " good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fm\n",
+ " a cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hyp\n",
+ " ervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp ibrs_\n",
+ " enhanced fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm mpx a\n",
+ " vx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl \n",
+ " xsaveopt xsavec xgetbv1 xsaves arat avx512_vnni md_clear arch_capabilities\n",
+ "Virtualization features: \n",
+ " Hypervisor vendor: KVM\n",
+ " Virtualization type: full\n",
+ "Caches (sum of all): \n",
+ " L1d: 192 KiB (6 instances)\n",
+ " L1i: 192 KiB (6 instances)\n",
+ " L2: 6 MiB (6 instances)\n",
+ " L3: 38.5 MiB (1 instance)\n",
+ "NUMA: \n",
+ " NUMA node(s): 1\n",
+ " NUMA node0 CPU(s): 0-11\n",
+ "Vulnerabilities: \n",
+ " Gather data sampling: Not affected\n",
+ " Itlb multihit: Not affected\n",
+ " L1tf: Not affected\n",
+ " Mds: Vulnerable; SMT Host state unknown\n",
+ " Meltdown: Not affected\n",
+ " Mmio stale data: Vulnerable\n",
+ " Retbleed: Vulnerable\n",
+ " Spec rstack overflow: Not affected\n",
+ " Spec store bypass: Vulnerable\n",
+ " Spectre v1: Vulnerable: __user pointer sanitization and usercopy barriers only; no swap\n",
+ " gs barriers\n",
+ " Spectre v2: Vulnerable, IBPB: disabled, STIBP: disabled, PBRSB-eIBRS: Vulnerable\n",
+ " Srbds: Not affected\n",
+ " Tsx async abort: Vulnerable\n"
+ ]
}
+ ],
+ "source": [
+ "!lscpu"
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "gpuType": "A100",
+ "machine_shape": "hm",
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
},
- "nbformat": 4,
- "nbformat_minor": 0
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.13"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
}