From 72f81e56330e779d1a8c36952028e1fc6d4cb531 Mon Sep 17 00:00:00 2001 From: Hatem Helal Date: Thu, 14 Sep 2023 13:49:35 +0000 Subject: [PATCH 01/13] WIP: Implementation of THO integrals --- .github/workflows/unittest.yaml | 38 ++++++++++ pyscf_ipu/experimental/basis.py | 73 +++++++++++++++++++ pyscf_ipu/experimental/mesh.py | 38 ++++++++++ pyscf_ipu/experimental/orbital.py | 49 +++++++++++++ pyscf_ipu/experimental/primitive.py | 52 ++++++++++++++ pyscf_ipu/experimental/structure.py | 104 ++++++++++++++++++++++++++++ pyscf_ipu/experimental/types.py | 11 +++ pyscf_ipu/experimental/units.py | 12 ++++ requirements_core.txt | 2 + test/test_experimental.py | 45 ++++++++++++ 10 files changed, 424 insertions(+) create mode 100644 .github/workflows/unittest.yaml create mode 100644 pyscf_ipu/experimental/basis.py create mode 100644 pyscf_ipu/experimental/mesh.py create mode 100644 pyscf_ipu/experimental/orbital.py create mode 100644 pyscf_ipu/experimental/primitive.py create mode 100644 pyscf_ipu/experimental/structure.py create mode 100644 pyscf_ipu/experimental/types.py create mode 100644 pyscf_ipu/experimental/units.py create mode 100644 test/test_experimental.py diff --git a/.github/workflows/unittest.yaml b/.github/workflows/unittest.yaml new file mode 100644 index 0000000..67afe73 --- /dev/null +++ b/.github/workflows/unittest.yaml @@ -0,0 +1,38 @@ +name: unit tests +on: + pull_request: + push: + branches: [main] + +jobs: + pytest-container: + runs-on: ubuntu-latest + container: + image: graphcore/pytorch:3.2.0-ubuntu-20.04 + + steps: + - uses: actions/checkout@v3 + + - name : Install package dependencies + run: | + apt update -y + apt install git -y + + - name: Install requirements + run: | + pip install -U pip + pip install -r requirements_ipu.txt + pip install -r requirements_test.txt + pip install -e . + + - name: Log installed environment + run: | + python3 -m pip freeze + + - name: Run unit tests + env: + JAX_IPU_USE_MODEL: 1 + JAX_IPU_MODEL_NUM_TILES: 46 + JAX_PLATFORMS: cpu,ipu + run: | + pytest -s . diff --git a/pyscf_ipu/experimental/basis.py b/pyscf_ipu/experimental/basis.py new file mode 100644 index 0000000..eb7ebad --- /dev/null +++ b/pyscf_ipu/experimental/basis.py @@ -0,0 +1,73 @@ +# Copyright (c) 2023 Graphcore Ltd. All rights reserved. +from typing import Tuple + +import chex +import jax.numpy as jnp + +from .orbital import Orbital +from .structure import Structure +from .types import FloatN, FloatNx3, FloatNxM + + +@chex.dataclass +class Basis: + orbitals: Tuple[Orbital] + structure: Structure + + @property + def num_orbitals(self) -> int: + return len(self.orbitals) + + @property + def num_primitives(self) -> int: + return sum(ao.num_primitives for ao in self.orbitals) + + @property + def occupancy(self) -> FloatN: + # Assumes uncharged systems in restricted Kohn-Sham + occ = jnp.full(self.num_orbitals, 2.0) + mask = occ.cumsum() > self.structure.num_electrons + occ = occ.at[mask].set(0.0) + return occ + + def __call__(self, pos: FloatNx3) -> FloatNxM: + return jnp.hstack([o(pos) for o in self.orbitals]) + + +def basisset(structure: Structure, basis_name: str = "sto-3g"): + from basis_set_exchange import get_basis + from basis_set_exchange.sort import sort_basis + + LMN_MAP = { + 0: [(0, 0, 0)], + 1: [(1, 0, 0), (0, 1, 0), (0, 0, 1)], + 2: [(2, 0, 0), (1, 1, 0), (1, 0, 1), (0, 2, 0), (0, 1, 1), (0, 0, 2)], + } + + bse_basis = get_basis( + basis_name, + elements=structure.atomic_symbol, + uncontract_spdf=True, + uncontract_general=False, + ) + bse_basis = sort_basis(bse_basis)["elements"] + orbitals = [] + + for a in range(structure.num_atoms): + center = structure.position[a, :] + shells = bse_basis[str(structure.atomic_number[a])]["electron_shells"] + + for s in shells: + for lmn in LMN_MAP[s["angular_momentum"][0]]: + ao = Orbital.from_bse( + center=center, + alphas=jnp.array(s["exponents"], dtype=float), + lmn=jnp.array(lmn, dtype=jnp.int32), + coefficients=jnp.array(s["coefficients"], dtype=float), + ) + orbitals.append(ao) + + return Basis( + orbitals=orbitals, + structure=structure, + ) diff --git a/pyscf_ipu/experimental/mesh.py b/pyscf_ipu/experimental/mesh.py new file mode 100644 index 0000000..3af7cdc --- /dev/null +++ b/pyscf_ipu/experimental/mesh.py @@ -0,0 +1,38 @@ +# Copyright (c) 2023 Graphcore Ltd. All rights reserved. +from typing import Union, Tuple, Optional +import jax.numpy as jnp +from .basis import Basis +from .types import FloatNx3, FloatN, FloatNxN + + +def uniform_mesh( + n: Union[int, Tuple] = 50, b: Union[float, Tuple] = 10.0, ndim: int = 3 +): + if isinstance(n, int): + n = (n,) * ndim + + if isinstance(b, float): + b = (b,) * ndim + + if not isinstance(n, (tuple, list)): + raise ValueError("Expected an integer ") + + if len(n) != ndim: + raise ValueError("n must be a tuple with {ndim} elements") + + if len(b) != ndim: + raise ValueError("b must be a tuple with {ndim} elements") + + axes = [jnp.linspace(-bi, bi, ni) for bi, ni in zip(b, n)] + mesh = jnp.stack(jnp.meshgrid(*axes, indexing="ij"), axis=-1) + mesh = mesh.reshape(-1, ndim) + return mesh + + +def electron_density( + basis: Basis, mesh: FloatNx3, C: Optional[FloatNxN] = None +) -> FloatN: + C = jnp.eye(basis.num_orbitals) if C is None else C + orbitals = basis(mesh) @ C + density = jnp.sum(basis.occupancy * orbitals * orbitals, axis=-1) + return density diff --git a/pyscf_ipu/experimental/orbital.py b/pyscf_ipu/experimental/orbital.py new file mode 100644 index 0000000..518daaf --- /dev/null +++ b/pyscf_ipu/experimental/orbital.py @@ -0,0 +1,49 @@ +# Copyright (c) 2023 Graphcore Ltd. All rights reserved. + +from typing import Tuple +from functools import partial + +import chex +import jax.numpy as jnp +from jax import tree_map, vmap + +from .primitive import Primitive, eval_primitive +from .types import FloatN, FloatNx3 + + +@chex.dataclass +class Orbital: + primitives: Tuple[Primitive] + coefficients: FloatN + + @property + def num_primitives(self) -> int: + return len(self.primitives) + + def __call__(self, pos: FloatNx3) -> FloatN: + assert pos.shape[-1] == 3, "pos must be have shape [N,3]" + + @partial(vmap, in_axes=(0, 0, None)) + def eval_orbital(p: Primitive, coef: float, pos: FloatNx3): + return coef * eval_primitive(p, pos) + + batch = tree_map(lambda *xs: jnp.stack(xs), *self.primitives) + out = jnp.sum(eval_orbital(batch, self.coefficients, pos), axis=0) + return out + + @staticmethod + def from_bse(center, alphas, lmn, coefficients): + coefficients = coefficients.reshape(-1) + assert len(coefficients) == len(alphas), "Expecting same size vectors!" + p = [Primitive(center=center, alpha=a, lmn=lmn) for a in alphas] + return Orbital(primitives=p, coefficients=coefficients) + + +def batch_orbitals(orbitals: Tuple[Orbital]): + primitives = [p for o in orbitals for p in o.primitives] + primitives = tree_map(lambda *xs: jnp.stack(xs), *primitives) + coefficients = jnp.concatenate([o.coefficients for o in orbitals]) + orbital_index = jnp.concatenate( + [i * jnp.ones(len(o), dtype=jnp.int32) for i, o in enumerate(orbitals)] + ) + return primitives, coefficients, orbital_index diff --git a/pyscf_ipu/experimental/primitive.py b/pyscf_ipu/experimental/primitive.py new file mode 100644 index 0000000..283ac9e --- /dev/null +++ b/pyscf_ipu/experimental/primitive.py @@ -0,0 +1,52 @@ +# Copyright (c) 2023 Graphcore Ltd. All rights reserved. +from typing import Optional + +import chex +import jax.numpy as jnp +from jax.scipy.special import gammaln + +from .types import Float3, FloatN, FloatNx3, Int3 + + +@chex.dataclass +class Primitive: + center: Float3 = jnp.zeros(3, dtype=jnp.float32) + alpha: float = 1.0 + lmn: Int3 = jnp.zeros(3, dtype=jnp.int32) + norm: Optional[float] = None + + def __post_init__(self): + if self.norm is None: + self.norm = normalize(self.lmn, self.alpha) + + @property + def angular_momentum(self) -> int: + return jnp.sum(self.lmn) + + def __call__(self, pos: FloatNx3) -> FloatN: + return eval_primitive(self, pos) + + +def normalize(lmn: Int3, alpha: float) -> float: + L = jnp.sum(lmn) + N = ((1 / 2) / alpha) ** (L + 3 / 2) + N *= jnp.exp(jnp.sum(gammaln(lmn + 1 / 2))) + return N**-0.5 + + +def product(a: Primitive, b: Primitive) -> Primitive: + alpha = a.alpha + b.alpha + center = (a.alpha * a.center + b.alpha * b.center) / alpha + lmn = a.lmn + b.lmn + c = a.norm * b.norm + Rab = a.center - b.center + c *= jnp.exp(-a.alpha * b.alpha / alpha * jnp.inner(Rab, Rab)) + return Primitive(center=center, alpha=alpha, lmn=lmn, norm=c) + + +def eval_primitive(p: Primitive, pos: FloatNx3) -> FloatN: + assert pos.shape[-1] == 3, "pos must be have shape [N,3]" + pos_translated = pos[:, jnp.newaxis] - p.center + v = p.norm * jnp.exp(-p.alpha * jnp.sum(pos_translated**2, axis=-1)) + v *= jnp.prod(pos_translated**p.lmn, axis=-1) + return v diff --git a/pyscf_ipu/experimental/structure.py b/pyscf_ipu/experimental/structure.py new file mode 100644 index 0000000..20d9fd2 --- /dev/null +++ b/pyscf_ipu/experimental/structure.py @@ -0,0 +1,104 @@ +# Copyright (c) 2023 Graphcore Ltd. All rights reserved. +from typing import List + +import chex +import numpy as np +from periodictable import elements +from py3Dmol import view +from pyscf import gto + +from .types import FloatNx3, IntN +from .units import to_angstrom, to_bohr + + +@chex.dataclass +class Structure: + atomic_number: IntN + position: FloatNx3 + is_bohr: bool = True + + def __post_init__(self): + if not self.is_bohr: + self.position = to_bohr(self.position) + + @property + def num_atoms(self) -> int: + return len(self.atomic_number) + + @property + def atomic_symbol(self) -> List[str]: + return [elements[z].symbol for z in self.atomic_number] + + @property + def num_electrons(self) -> int: + return np.sum(self.atomic_number) + + def to_xyz(self) -> str: + xyz = f"{self.num_atoms}\n\n" + sym = self.atomic_symbol + pos = to_angstrom(self.position) + + for i in range(self.num_atoms): + r = np.array2string(pos[i, :], separator="\t")[1:-1] + xyz += f"{sym[i]}\t{r}\n" + + return xyz + + def view(self) -> "view": + return view(data=self.to_xyz(), style={"stick": {"radius": 0.06}}) + + +def to_pyscf( + structure: Structure, basis_name: str = "sto-3g", unit: str = "Bohr" +) -> "gto.Mole": + mol = gto.Mole(unit=unit, spin=structure.num_electrons % 2, cart=True) + mol.atom = [ + (symbol, pos) + for symbol, pos in zip(structure.atomic_symbol, structure.position) + ] + mol.basis = basis_name + mol.build(unit=unit) + return mol + + +def water(): + r"""Single water molecule + Structure of single water molecule calculated with DFT using B3LYP + functional and 6-31+G** basis set """ + return Structure( + atomic_number=np.array([8, 1, 1]), + position=np.array( + [ + [0.0000, 0.0000, 0.1165], + [0.0000, 0.7694, -0.4661], + [0.0000, -0.7694, -0.4661], + ] + ), + is_bohr=False, + ) + + +def benzene(): + r"""Benzene ring + Structure of benzene ring calculated with DFT using B3LYP functional + and 6-31+G** basis set """ + return Structure( + atomic_number=np.array([6, 6, 6, 6, 6, 6, 1, 1, 1, 1, 1, 1]), + position=np.array( + [ + [0.0000000, 1.3983460, 0.0000000], + [1.2110030, 0.6991730, 0.0000000], + [1.2110030, -0.6991730, 0.0000000], + [0.0000000, -1.3983460, 0.0000000], + [-1.211003, -0.699173, 0.0000000], + [-1.211003, 0.6991730, 0.0000000], + [0.0000000, 2.4847510, 0.0000000], + [2.1518570, 1.2423750, 0.0000000], + [2.1518570, -1.2423750, 0.0000000], + [0.0000000, -2.4847510, 0.0000000], + [-2.151857, -1.242375, 0.0000000], + [-2.151857, 1.2423750, 0.0000000], + ] + ), + is_bohr=False, + ) diff --git a/pyscf_ipu/experimental/types.py b/pyscf_ipu/experimental/types.py new file mode 100644 index 0000000..b8450db --- /dev/null +++ b/pyscf_ipu/experimental/types.py @@ -0,0 +1,11 @@ +# Copyright (c) 2023 Graphcore Ltd. All rights reserved. +from jaxtyping import Float, Int, Array + + +Float3 = Float[Array, "3"] +FloatNx3 = Float[Array, "N 3"] +FloatN = Float[Array, "N"] +FloatNxN = Float[Array, "N N"] +FloatNxM = Float[Array, "N M"] +Int3 = Int[Array, "3"] +IntN = Int[Array, "N"] diff --git a/pyscf_ipu/experimental/units.py b/pyscf_ipu/experimental/units.py new file mode 100644 index 0000000..4099f40 --- /dev/null +++ b/pyscf_ipu/experimental/units.py @@ -0,0 +1,12 @@ +# Copyright (c) 2023 Graphcore Ltd. All rights reserved. +from jaxtyping import Array + +BOHR_PER_ANGSTROM = 0.529177210903 + + +def to_angstrom(bohr_value: Array) -> Array: + return bohr_value / BOHR_PER_ANGSTROM + + +def to_bohr(angstrom_value: Array) -> Array: + return angstrom_value * BOHR_PER_ANGSTROM diff --git a/requirements_core.txt b/requirements_core.txt index 17e5ef5..906c37f 100644 --- a/requirements_core.txt +++ b/requirements_core.txt @@ -21,6 +21,8 @@ jsonargparse[all] mogli imageio[ffmpeg] py3Dmol +basis-set-exchange +periodictable # silence warnings about setuptools + numpy setuptools < 60.0 diff --git a/test/test_experimental.py b/test/test_experimental.py new file mode 100644 index 0000000..f923646 --- /dev/null +++ b/test/test_experimental.py @@ -0,0 +1,45 @@ +# Copyright (c) 2023 Graphcore Ltd. All rights reserved. +import pytest +import numpy as np +import jax.numpy as jnp +from numpy.testing import assert_allclose + +from pyscf_ipu.experimental.basis import basisset +from pyscf_ipu.experimental.structure import water, to_pyscf +from pyscf_ipu.experimental.mesh import uniform_mesh, electron_density + + +@pytest.mark.parametrize("basis_name", ["sto-3g", "6-31g**"]) +def test_to_pyscf(basis_name): + mol = water() + basis = basisset(mol, basis_name) + pyscf_mol = to_pyscf(mol, basis_name) + assert basis.num_orbitals == pyscf_mol.nao + + +@pytest.mark.parametrize("basis_name", ["sto-3g", "6-31+g"]) +def test_gto(basis_name): + from pyscf.dft.numint import eval_rho + + # Atomic orbitals + structure = water() + basis = basisset(structure, basis_name) + mesh = uniform_mesh() + actual = basis(mesh) + + mol = to_pyscf(structure, basis_name) + expect_ao = mol.eval_gto("GTOval_cart", np.asarray(mesh)) + assert_allclose(actual, expect_ao, atol=1e-6) + + # Molecular orbitals + mf = mol.KS() + mf.kernel() + C = jnp.array(mf.mo_coeff, dtype=jnp.float32) + actual = basis.occupancy * C @ C.T + expect = jnp.array(mf.make_rdm1(), dtype=jnp.float32) + assert_allclose(actual, expect, atol=1e-6) + + # Electron density + actual = electron_density(basis, mesh, C) + expect = eval_rho(mol, expect_ao, mf.make_rdm1(), "lda") + assert_allclose(actual, expect, atol=1e-6) From dbcd13070735a833ed4f28b151cc6a6b51f046af Mon Sep 17 00:00:00 2001 From: Hatem Helal Date: Thu, 14 Sep 2023 15:03:31 +0000 Subject: [PATCH 02/13] adding impl --- pyscf_ipu/experimental/integrals.py | 341 ++++++++++++++++++++++++++++ pyscf_ipu/experimental/orbital.py | 5 +- test/test_experimental.py | 180 ++++++++++++++- 3 files changed, 521 insertions(+), 5 deletions(-) create mode 100644 pyscf_ipu/experimental/integrals.py diff --git a/pyscf_ipu/experimental/integrals.py b/pyscf_ipu/experimental/integrals.py new file mode 100644 index 0000000..8aa895a --- /dev/null +++ b/pyscf_ipu/experimental/integrals.py @@ -0,0 +1,341 @@ +# Copyright (c) 2023 Graphcore Ltd. All rights reserved. +from dataclasses import asdict +from functools import partial +from itertools import product as cartesian_product +from typing import Callable + +import jax.numpy as jnp +import numpy as np +from jax import lax, vmap, jit, tree_map +from jax.ops import segment_sum +from jax.scipy.special import gammainc, gammaln + +from .basis import Basis +from .orbital import batch_orbitals +from .primitive import Primitive, product +from .types import IntN, FloatN, FloatNxN, Float3, FloatNx3 + +LMAX = 4 + +""" +Special functions used in integral evaluation +""" + + +def factorial(n: IntN, nmax: int = LMAX) -> IntN: + def body_fun(i, val): + return val * jnp.where(i <= n, i, 1) + + return lax.fori_loop(1, nmax + 1, body_fun, jnp.ones_like(n)) + + +def factorial2(n: IntN, nmax: int = 2 * LMAX) -> IntN: + def body_fun(i, val): + return val * jnp.where((i <= n) & (n % 2 == i % 2), i, 1) + + return lax.fori_loop(1, nmax + 1, body_fun, jnp.ones_like(n)) + + +def binom(x: IntN, y: IntN, nmax: int = LMAX) -> IntN: + bang = partial(factorial, nmax=nmax) + c = x * bang(x - 1) / (bang(y) * bang(x - y)) + return jnp.where(x == y, 1, c) + + +def gammanu(nu: IntN, t: FloatN, epsilon: float = 1e-10) -> FloatN: + """ + eq 2.11 from THO but simplified using SymPy and converted to jax + + t, u = symbols("t u", real=True, positive=True) + nu = Symbol("nu", integer=True, nonnegative=True) + + expr = simplify(integrate(u ** (2 * nu) * exp(-t * u**2), (u, 0, 1))) + f = lambdify((nu, t), expr, modules="scipy") + ?f + + We evaulate this in log-space to avoid overflow/nan + """ + t = jnp.maximum(t, epsilon) + x = nu + 0.5 + gn = jnp.log(0.5) - x * jnp.log(t) + jnp.log(gammainc(x, t)) + gammaln(x) + return jnp.exp(gn) + + +@partial(vmap, in_axes=(0, None, None, None, None)) +def binom_factor(s: IntN, i: int, j: int, pa: float, pb: float): + """ + Eq. 15 from Augspurger JD, Dykstra CE. General quantum mechanical operators. An + open-ended approach for one-electron integrals with Gaussian bases. Journal of + computational chemistry. 1990 Jan;11(1):105-11. + + """ + + def term(t): + return binom(i, s - t) * binom(j, t) * pa ** (i - s + t) * pb ** (j - t) + + def body_fun(t, val): + mask = (t <= s) & (t >= (s - i)) & (t <= j) + return val + jnp.where(mask, term(t), 0.0) + + return lax.fori_loop(0, LMAX + 1, body_fun, 0.0) + + +@partial(vmap, in_axes=(0, 0, 0, 0, None)) +def overlap_axis(i: int, j: int, pa: int, pb: int, alpha: float) -> float: + ii = jnp.arange(LMAX + 1) + out = binom_factor(2 * ii, i, j, pa, pb) + out *= factorial2(2 * ii - 1) / (2 * alpha) ** ii + mask = ii <= jnp.floor_divide(i + j, 2) + out = jnp.where(mask, out, 0.0) + return jnp.sum(out) + + +def overlap_basis(b: Basis) -> FloatNxN: + return integrate(b, vmap_overlap_primitives) + + +def integrate(b: Basis, primitive_op: Callable) -> FloatNxN: + def take_primitives(indices): + p = tree_map(lambda x: jnp.take(x, indices, axis=0), primitives) + c = jnp.take(coefficients, indices) + return p, c + + primitives, coefficients, orbital_index = batch_orbitals(b.orbitals) + ii, jj = jnp.triu_indices(b.num_primitives) + lhs, cl = take_primitives(ii.reshape(-1)) + rhs, cr = take_primitives(jj.reshape(-1)) + aij = cl * cr * primitive_op(lhs, rhs) + A = jnp.zeros((b.num_primitives, b.num_primitives)) + A = A.at[ii, jj].set(aij) + A = A + A.T - jnp.diag(jnp.diag(A)) + index = orbital_index.reshape(1, -1) + return segment_sum(segment_sum(A, index).T, index) + + +def _overlap_primitives(a: Primitive, b: Primitive) -> float: + p = product(a, b) + pa = p.center - a.center + pb = p.center - b.center + out = jnp.power(jnp.pi / p.alpha, 1.5) * p.norm + out *= jnp.prod(overlap_axis(a.lmn, b.lmn, pa, pb, p.alpha)) + return out + + +def _kinetic_primitives(a: Primitive, b: Primitive) -> float: + t0 = b.alpha * (2 * jnp.sum(b.lmn) + 3) * _overlap_primitives(a, b) + + def offset_qn(ax: int, offset: int): + lmn = b.lmn.at[ax].add(offset) + return Primitive(**{**asdict(b), "lmn": lmn}) + + axes = jnp.arange(3) + b1 = vmap(offset_qn, (0, None))(axes, 2) + t1 = jnp.sum(vmap(_overlap_primitives, (None, 0))(a, b1)) + + b2 = vmap(offset_qn, (0, None))(axes, -2) + t2 = jnp.sum(b.lmn * (b.lmn - 1) * vmap(_overlap_primitives, (None, 0))(a, b2)) + return t0 - 2.0 * b.alpha**2 * t1 - 0.5 * t2 + + +def kinetic_basis(b: Basis) -> FloatNxN: + return integrate(b, vmap_kinetic_primitives) + + +def build_gindex(): + vals = [ + (i, r, u) + for i in range(LMAX + 1) + for r in range(i // 2 + 1) + for u in range((i - 2 * r) // 2 + 1) + ] + i, r, u = jnp.array(vals).T + return i, r, u + + +def _nuclear_primitives(a: Primitive, b: Primitive, c: Float3): + p = product(a, b) + pa = p.center - a.center + pb = p.center - b.center + pc = p.center - c + epsilon = 1.0 / (4.0 * p.alpha) + + @vmap + def g_term(l1, l2, pa, pb, cp): + i, r, u = build_gindex() + index = i - 2 * r - u + g = ( + jnp.power(-1, i + u) + * binom_factor(i, l1, l2, pa, pb) + * factorial(i) + * jnp.power(cp, index - u) + * jnp.power(epsilon, r + u) + ) / (factorial(r) * factorial(u) * factorial(index - u)) + + g = jnp.where(index <= l1 + l2, g, 0.0) + return jnp.zeros(LMAX + 1).at[index].add(g) + + Gi, Gj, Gk = g_term(a.lmn, b.lmn, pa, pb, pc) + + ijk = jnp.arange(LMAX + 1) + nu = ( + ijk[:, jnp.newaxis, jnp.newaxis] + + ijk[jnp.newaxis, :, jnp.newaxis] + + ijk[jnp.newaxis, jnp.newaxis, :] + ) + + W = ( + Gi[:, jnp.newaxis, jnp.newaxis] + * Gj[jnp.newaxis, :, jnp.newaxis] + * Gk[jnp.newaxis, jnp.newaxis, :] + * gammanu(nu, p.alpha * jnp.inner(pc, pc)) + ) + + return -2.0 * jnp.pi / p.alpha * p.norm * jnp.sum(W) + + +overlap_primitives = jit(_overlap_primitives) +kinetic_primitives = jit(_kinetic_primitives) +nuclear_primitives = jit(_nuclear_primitives) + +vmap_overlap_primitives = jit(vmap(_overlap_primitives)) +vmap_kinetic_primitives = jit(vmap(_kinetic_primitives)) +vmap_nuclear_primitives = jit(vmap(_nuclear_primitives)) + + +@partial(vmap, in_axes=(None, 0, 0)) +def nuclear_basis(b: Basis, c: FloatNx3, z: FloatN) -> FloatNxN: + op = partial(_nuclear_primitives, c=c) + op = vmap(op) + op = jit(op) + return z * integrate(b, op) + + +def build_cindex(): + vals = [ + (i1, i2, r1, r2, u) + for i1 in range(2 * LMAX + 1) + for i2 in range(2 * LMAX + 1) + for r1 in range(i1 // 2 + 1) + for r2 in range(i2 // 2 + 1) + for u in range((i1 + i2) // 2 - r1 - r2 + 1) + ] + i1, i2, r1, r2, u = jnp.array(vals).T + return i1, i2, r1, r2, u + + +def _eri_primitives(a: Primitive, b: Primitive, c: Primitive, d: Primitive) -> float: + p = product(a, b) + q = product(c, d) + pa = p.center - a.center + pb = p.center - b.center + qc = q.center - c.center + qd = q.center - d.center + qp = q.center - p.center + delta = 1 / (4.0 * p.alpha) + 1 / (4.0 * q.alpha) + + def H(l1, l2, a, b, i, r, gamma): + # Note this should match THO Eq 3.5 but that seems to incorrectly show a + # 1/(4 gamma) ^(i- 2r) term which is inconsistent with Eq 2.22. + # Using (4 gamma)^(r - i) matches the reported expressions for H_L + u = factorial(i) * binom_factor(i, l1, l2, a, b) + v = factorial(r) * factorial(i - 2 * r) * (4 * gamma) ** (i - r) + return u / v + + def c_term(la, lb, lc, ld, pa, pb, qc, qd, qp): + # THO Eq 2.22 and 3.4 + i1, i2, r1, r2, u = build_cindex() + h = H(la, lb, pa, pb, i1, r1, p.alpha) * H(lc, ld, qc, qd, i2, r2, q.alpha) + index = i1 + i2 - 2 * (r1 + r2) - u + x = (-1) ** (i2 + u) * factorial(index + u) * qp ** (index - u) + y = factorial(u) * factorial(index - u) * delta**index + c = h * x / y + + mask = (i1 <= (la + lb)) & (i2 <= (lc + ld)) + c = jnp.where(mask, c, 0.0) + return segment_sum(c, index, num_segments=4 * LMAX + 1) + + # Manual vmap over cartesian axes (x, y, z) as ran into possible bug. + args = [a.lmn, b.lmn, c.lmn, d.lmn, pa, pb, qc, qd, qp] + Ci, Cj, Ck = [c_term(*[v.at[i].get() for v in args]) for i in range(3)] + + ijk = jnp.arange(4 * LMAX + 1) + nu = ( + ijk[:, jnp.newaxis, jnp.newaxis] + + ijk[jnp.newaxis, :, jnp.newaxis] + + ijk[jnp.newaxis, jnp.newaxis, :] + ) + + W = ( + Ci[:, jnp.newaxis, jnp.newaxis] + * Cj[jnp.newaxis, :, jnp.newaxis] + * Ck[jnp.newaxis, jnp.newaxis, :] + * gammanu(nu, jnp.inner(qp, qp) / (4.0 * delta)) + ) + + return ( + 2.0 + * jnp.pi**2 + / (p.alpha * q.alpha) + * jnp.sqrt(jnp.pi / (p.alpha + q.alpha)) + * p.norm + * q.norm + * jnp.sum(W) + ) + + +eri_primitives = jit(_eri_primitives) +vmap_eri_primitives = jit(vmap(_eri_primitives)) + + +def gen_ijkl(n: int): + """ + adapted from four-index transformations by S Wilson pg 257 + """ + + def lmax(i: int, j: int, k: int): + return j if i == k else k + + for idx in range(n): + for jdx in range(idx + 1): + for kdx in range(idx + 1): + for ldx in range(lmax(idx, jdx, kdx) + 1): + yield idx, jdx, kdx, ldx + + +def eri_basis_sparse(b: Basis): + indices = [] + batch = [] + offset = np.cumsum([o.num_primitives for o in b.orbitals]) + offset = np.insert(offset, 0, 0) + + for count, idx in enumerate(gen_ijkl(b.num_orbitals)): + mesh = [range(offset[i], offset[i + 1]) for i in idx] + indices += list(cartesian_product(*mesh)) + batch += [count] * (len(indices) - len(batch)) + + indices = jnp.array(indices, dtype=jnp.int32).T + batch = jnp.array(batch, dtype=jnp.int32) + primitives, coefficients, _ = batch_orbitals(b.orbitals) + cijkl = jnp.stack([jnp.take(coefficients, idx) for idx in indices]).prod(axis=0) + pijkl = [ + tree_map(lambda x: jnp.take(x, idx, axis=0), primitives) for idx in indices + ] + eris = cijkl * vmap_eri_primitives(*pijkl) + return segment_sum(eris, batch) + + +def eri_basis(b: Basis): + unique_eris = eri_basis_sparse(b) + ii, jj, kk, ll = jnp.array(list(gen_ijkl(b.num_orbitals)), dtype=jnp.int32).T + + # Apply 8x permutation symmetry to build dense ERI from sparse ERI. + eri_dense = jnp.empty((b.num_orbitals,) * 4, dtype=jnp.float32) + eri_dense = eri_dense.at[ii, jj, kk, ll].set(unique_eris) + eri_dense = eri_dense.at[ii, jj, ll, kk].set(unique_eris) + eri_dense = eri_dense.at[jj, ii, kk, ll].set(unique_eris) + eri_dense = eri_dense.at[jj, ii, ll, kk].set(unique_eris) + eri_dense = eri_dense.at[kk, ll, ii, jj].set(unique_eris) + eri_dense = eri_dense.at[kk, ll, jj, ii].set(unique_eris) + eri_dense = eri_dense.at[ll, kk, ii, jj].set(unique_eris) + eri_dense = eri_dense.at[ll, kk, jj, ii].set(unique_eris) + return eri_dense diff --git a/pyscf_ipu/experimental/orbital.py b/pyscf_ipu/experimental/orbital.py index 518daaf..2cb4340 100644 --- a/pyscf_ipu/experimental/orbital.py +++ b/pyscf_ipu/experimental/orbital.py @@ -44,6 +44,9 @@ def batch_orbitals(orbitals: Tuple[Orbital]): primitives = tree_map(lambda *xs: jnp.stack(xs), *primitives) coefficients = jnp.concatenate([o.coefficients for o in orbitals]) orbital_index = jnp.concatenate( - [i * jnp.ones(len(o), dtype=jnp.int32) for i, o in enumerate(orbitals)] + [ + i * jnp.ones(o.num_primitives, dtype=jnp.int32) + for i, o in enumerate(orbitals) + ] ) return primitives, coefficients, orbital_index diff --git a/test/test_experimental.py b/test/test_experimental.py index f923646..7a7acca 100644 --- a/test/test_experimental.py +++ b/test/test_experimental.py @@ -1,12 +1,25 @@ # Copyright (c) 2023 Graphcore Ltd. All rights reserved. -import pytest -import numpy as np import jax.numpy as jnp +import numpy as np +import pytest +from jax import tree_map, vmap from numpy.testing import assert_allclose from pyscf_ipu.experimental.basis import basisset -from pyscf_ipu.experimental.structure import water, to_pyscf -from pyscf_ipu.experimental.mesh import uniform_mesh, electron_density +from pyscf_ipu.experimental.integrals import ( + overlap_primitives, + nuclear_basis, + overlap_basis, + nuclear_primitives, + eri_basis, + eri_basis_sparse, + eri_primitives, + kinetic_primitives, + kinetic_basis, +) +from pyscf_ipu.experimental.mesh import electron_density, uniform_mesh +from pyscf_ipu.experimental.primitive import Primitive +from pyscf_ipu.experimental.structure import to_pyscf, water, Structure @pytest.mark.parametrize("basis_name", ["sto-3g", "6-31g**"]) @@ -43,3 +56,162 @@ def test_gto(basis_name): actual = electron_density(basis, mesh, C) expect = eval_rho(mol, expect_ao, mf.make_rdm1(), "lda") assert_allclose(actual, expect, atol=1e-6) + + +def test_overlap(): + # Exercise 3.21 of "Modern quantum chemistry: introduction to advanced + # electronic structure theory."" by Szabo and Ostlund + alpha = 0.270950 * 1.24 * 1.24 + a = Primitive(alpha=alpha) + b = Primitive(alpha=alpha, center=jnp.array([1.4, 0.0, 0.0])) + assert_allclose(overlap_primitives(a, a), 1.0, atol=1e-5) + assert_allclose(overlap_primitives(b, b), 1.0, atol=1e-5) + assert_allclose(overlap_primitives(b, a), 0.6648, atol=1e-5) + + +@pytest.mark.parametrize("basis_name", ["sto-3g", "6-31+g", "6-31+g*"]) +def test_water_overlap(basis_name): + basis = basisset(water(), basis_name) + actual_overlap = overlap_basis(basis) + + # Note: PySCF doesn't appear to normalise d basis functions in cartesian basis + expect_overlap = to_pyscf(water(), basis_name=basis_name).intor("int1e_ovlp_cart") + n = 1 / np.sqrt(np.diagonal(expect_overlap)) + expect_overlap = n[:, None] * n[None, :] * expect_overlap + assert_allclose(actual_overlap, expect_overlap, atol=1e-6) + + +def test_kinetic(): + # PyQuante test case for kinetic primitive integral + p = Primitive() + assert_allclose(kinetic_primitives(p, p), 1.5, atol=1e-6) + + # Reproduce the kinetic energy matrix for H2 using STO-3G basis set + # See equation 3.230 of "Modern quantum chemistry: introduction to advanced + # electronic structure theory."" by Szabo and Ostlund + h2 = Structure( + atomic_number=np.array([1, 1]), + position=np.array([[0.0, 0.0, 0.0], [1.4, 0.0, 0.0]]), + ) + basis = basisset(h2, "sto-3g") + actual = kinetic_basis(basis) + expect = np.array([[0.7600, 0.2365], [0.2365, 0.7600]]) + assert_allclose(actual, expect, atol=1e-4) + + +@pytest.mark.parametrize( + "basis_name", + [ + "sto-3g", + "6-31+g", + pytest.param( + "6-31+g*", marks=pytest.mark.xfail(reason="Cartesian norm problem?") + ), + ], +) +def test_water_kinetic(basis_name): + basis = basisset(water(), basis_name) + actual = kinetic_basis(basis) + + expect = to_pyscf(water(), basis_name=basis_name).intor("int1e_kin_cart") + assert_allclose(actual, expect, atol=1e-4) + + +def test_nuclear(): + # PyQuante test case for nuclear attraction integral + p = Primitive() + c = jnp.zeros(3) + assert_allclose(nuclear_primitives(p, p, c), -1.595769, atol=1e-5) + + # Reproduce the nuclear attraction matrix for H2 using STO-3G basis set + # See equation 3.231 and 3.232 of Szabo and Ostlund + h2 = Structure( + atomic_number=np.array([1, 1]), + position=np.array([[0.0, 0.0, 0.0], [1.4, 0.0, 0.0]]), + ) + basis = basisset(h2, "sto-3g") + actual = nuclear_basis(basis, h2.position, h2.atomic_number) + expect = np.array( + [ + [[-1.2266, -0.5974], [-0.5974, -0.6538]], + [[-0.6538, -0.5974], [-0.5974, -1.2266]], + ] + ) + + assert_allclose(actual, expect, atol=1e-4) + + +def test_water_nuclear(): + basis_name = "sto-3g" + h2o = water() + basis = basisset(h2o, basis_name) + actual = nuclear_basis(basis, h2o.position, h2o.atomic_number).sum(axis=0) + expect = to_pyscf(h2o, basis_name=basis_name).intor("int1e_nuc_cart") + assert_allclose(actual, expect, atol=1e-4) + + +def eri_orbitals(orbitals): + def take(orbital, index): + p = tree_map(lambda *xs: jnp.stack(xs), *orbital.primitives) + p = tree_map(lambda x: jnp.take(x, index, axis=0), p) + c = jnp.take(orbital.coefficients, index) + return p, c + + indices = [jnp.arange(o.num_primitives) for o in orbitals] + indices = [i.reshape(-1) for i in jnp.meshgrid(*indices)] + prim, coef = zip(*[take(o, i) for o, i in zip(orbitals, indices)]) + return jnp.sum(jnp.prod(jnp.stack(coef), axis=0) * vmap(eri_primitives)(*prim)) + + +def test_eri(): + # PyQuante test cases for ERI + a, b, c, d = [Primitive()] * 4 + assert_allclose(eri_primitives(a, b, c, d), 1.128379, atol=1e-5) + + c, d = [Primitive(lmn=jnp.array([1, 0, 0]))] * 2 + assert_allclose(eri_primitives(a, b, c, d), 0.940316, atol=1e-5) + + # H2 molecule in sto-3g: See equation 3.235 of Szabo and Ostlund + h2 = Structure( + atomic_number=np.array([1, 1]), + position=np.array([[0.0, 0.0, 0.0], [1.4, 0.0, 0.0]]), + ) + basis = basisset(h2, "sto-3g") + indices = [(0, 0, 0, 0), (0, 0, 1, 1), (1, 0, 0, 0), (1, 0, 1, 0)] + expected = [0.7746, 0.5697, 0.4441, 0.2970] + + for ijkl, expect in zip(indices, expected): + actual = eri_orbitals([basis.orbitals[aoid] for aoid in ijkl]) + assert_allclose(actual, expect, atol=1e-4) + + +def test_eri_basis(): + # H2 molecule in sto-3g: See equation 3.235 of Szabo and Ostlund + h2 = Structure( + atomic_number=np.array([1, 1]), + position=np.array([[0.0, 0.0, 0.0], [1.4, 0.0, 0.0]]), + ) + basis = basisset(h2, "sto-3g") + + actual = eri_basis(basis) + expect = np.empty((2, 2, 2, 2), dtype=np.float32) + expect[0, 0, 0, 0] = expect[1, 1, 1, 1] = 0.7746 + expect[0, 0, 1, 1] = expect[1, 1, 0, 0] = 0.5697 + expect[1, 0, 0, 0] = expect[0, 0, 0, 1] = 0.4441 + expect[0, 1, 0, 0] = expect[0, 0, 1, 0] = 0.4441 + expect[0, 1, 1, 1] = expect[1, 1, 1, 0] = 0.4441 + expect[1, 0, 1, 1] = expect[1, 1, 0, 1] = 0.4441 + expect[1, 0, 1, 0] = expect[0, 1, 1, 0] = 0.2970 + expect[0, 1, 0, 1] = expect[1, 0, 0, 1] = 0.2970 + assert_allclose(actual, expect, atol=1e-4) + + +@pytest.mark.parametrize("sparse", [True, False]) +def test_water_eri(sparse): + basis_name = "sto-3g" + h2o = water() + basis = basisset(h2o, basis_name) + actual = eri_basis_sparse(basis) if sparse else eri_basis(basis) + aosym = "s8" if sparse else "s1" + expect = to_pyscf(h2o, basis_name=basis_name).intor("int2e_cart", aosym=aosym) + assert_allclose(actual, expect, atol=1e-4) From e863390645417230086bed85bfdf0dc44b9d4731 Mon Sep 17 00:00:00 2001 From: Hatem Helal Date: Thu, 14 Sep 2023 15:07:55 +0000 Subject: [PATCH 03/13] adding ipu test cases --- pyscf_ipu/experimental/device.py | 29 +++++++++++++++++++++++ test/test_experimental.py | 40 ++++++++++++++++++++++++++++++++ 2 files changed, 69 insertions(+) create mode 100644 pyscf_ipu/experimental/device.py diff --git a/pyscf_ipu/experimental/device.py b/pyscf_ipu/experimental/device.py new file mode 100644 index 0000000..110928a --- /dev/null +++ b/pyscf_ipu/experimental/device.py @@ -0,0 +1,29 @@ +# Copyright (c) 2023 Graphcore Ltd. All rights reserved. +from functools import partial, wraps +import numpy as np +from jax import devices, jit + + +def has_ipu() -> bool: + try: + return len(devices("ipu")) > 0 + except RuntimeError: + pass + + return False + + +ipu_jit = partial(jit, backend="ipu") + + +def ipu_func(func): + @wraps(func) + def wrapper(*args, **kwargs): + outputs = ipu_jit(func)(*args, **kwargs) + + if not isinstance(outputs, tuple): + return np.asarray(outputs) + + return [np.asarray(o) for o in outputs] + + return wrapper diff --git a/test/test_experimental.py b/test/test_experimental.py index 7a7acca..165fdd3 100644 --- a/test/test_experimental.py +++ b/test/test_experimental.py @@ -20,6 +20,7 @@ from pyscf_ipu.experimental.mesh import electron_density, uniform_mesh from pyscf_ipu.experimental.primitive import Primitive from pyscf_ipu.experimental.structure import to_pyscf, water, Structure +from pyscf_ipu.experimental.device import has_ipu, ipu_func @pytest.mark.parametrize("basis_name", ["sto-3g", "6-31g**"]) @@ -215,3 +216,42 @@ def test_water_eri(sparse): aosym = "s8" if sparse else "s1" expect = to_pyscf(h2o, basis_name=basis_name).intor("int2e_cart", aosym=aosym) assert_allclose(actual, expect, atol=1e-4) + + +@pytest.mark.skipif(not has_ipu(), reason="Skipping ipu test!") +def test_ipu_overlap(): + from pyscf_ipu.experimental.integrals import _overlap_primitives + + a, b = [Primitive()] * 2 + actual = ipu_func(_overlap_primitives)(a, b) + assert_allclose(actual, overlap_primitives(a, b)) + + +@pytest.mark.skipif(not has_ipu(), reason="Skipping ipu test!") +def test_ipu_kinetic(): + from pyscf_ipu.experimental.integrals import _kinetic_primitives + + a, b = [Primitive()] * 2 + actual = ipu_func(_kinetic_primitives)(a, b) + assert_allclose(actual, kinetic_primitives(a, b)) + + +@pytest.mark.skipif(not has_ipu(), reason="Skipping ipu test!") +def test_ipu_nuclear(): + from pyscf_ipu.experimental.integrals import _nuclear_primitives + + # PyQuante test case for nuclear attraction integral + a, b = [Primitive()] * 2 + c = jnp.zeros(3) + actual = ipu_func(_nuclear_primitives)(a, b, c) + assert_allclose(actual, -1.595769, atol=1e-5) + + +@pytest.mark.skipif(not has_ipu(), reason="Skipping ipu test!") +def test_ipu_eri(): + from pyscf_ipu.experimental.integrals import _eri_primitives + + # PyQuante test cases for ERI + a, b, c, d = [Primitive()] * 4 + actual = ipu_func(_eri_primitives)(a, b, c, d) + assert_allclose(actual, 1.128379, atol=1e-5) From e6d98abba5dce953b6e17ddcc478d2a16fdceefe Mon Sep 17 00:00:00 2001 From: Hatem Helal Date: Thu, 14 Sep 2023 15:08:19 +0000 Subject: [PATCH 04/13] fmt --- test/test_experimental.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/test/test_experimental.py b/test/test_experimental.py index 165fdd3..a22c58e 100644 --- a/test/test_experimental.py +++ b/test/test_experimental.py @@ -6,21 +6,21 @@ from numpy.testing import assert_allclose from pyscf_ipu.experimental.basis import basisset +from pyscf_ipu.experimental.device import has_ipu, ipu_func from pyscf_ipu.experimental.integrals import ( - overlap_primitives, - nuclear_basis, - overlap_basis, - nuclear_primitives, eri_basis, eri_basis_sparse, eri_primitives, - kinetic_primitives, kinetic_basis, + kinetic_primitives, + nuclear_basis, + nuclear_primitives, + overlap_basis, + overlap_primitives, ) from pyscf_ipu.experimental.mesh import electron_density, uniform_mesh from pyscf_ipu.experimental.primitive import Primitive -from pyscf_ipu.experimental.structure import to_pyscf, water, Structure -from pyscf_ipu.experimental.device import has_ipu, ipu_func +from pyscf_ipu.experimental.structure import Structure, to_pyscf, water @pytest.mark.parametrize("basis_name", ["sto-3g", "6-31g**"]) From 86c29e7470581587ce8e7ab0afc2e50a0f46b788 Mon Sep 17 00:00:00 2001 From: Hatem Helal Date: Thu, 14 Sep 2023 15:19:59 +0000 Subject: [PATCH 05/13] molecule factory --- pyscf_ipu/experimental/structure.py | 67 +++++++++++------------------ test/test_experimental.py | 39 ++++++----------- 2 files changed, 40 insertions(+), 66 deletions(-) diff --git a/pyscf_ipu/experimental/structure.py b/pyscf_ipu/experimental/structure.py index 20d9fd2..c1ba34d 100644 --- a/pyscf_ipu/experimental/structure.py +++ b/pyscf_ipu/experimental/structure.py @@ -61,44 +61,29 @@ def to_pyscf( return mol -def water(): - r"""Single water molecule - Structure of single water molecule calculated with DFT using B3LYP - functional and 6-31+G** basis set """ - return Structure( - atomic_number=np.array([8, 1, 1]), - position=np.array( - [ - [0.0000, 0.0000, 0.1165], - [0.0000, 0.7694, -0.4661], - [0.0000, -0.7694, -0.4661], - ] - ), - is_bohr=False, - ) - - -def benzene(): - r"""Benzene ring - Structure of benzene ring calculated with DFT using B3LYP functional - and 6-31+G** basis set """ - return Structure( - atomic_number=np.array([6, 6, 6, 6, 6, 6, 1, 1, 1, 1, 1, 1]), - position=np.array( - [ - [0.0000000, 1.3983460, 0.0000000], - [1.2110030, 0.6991730, 0.0000000], - [1.2110030, -0.6991730, 0.0000000], - [0.0000000, -1.3983460, 0.0000000], - [-1.211003, -0.699173, 0.0000000], - [-1.211003, 0.6991730, 0.0000000], - [0.0000000, 2.4847510, 0.0000000], - [2.1518570, 1.2423750, 0.0000000], - [2.1518570, -1.2423750, 0.0000000], - [0.0000000, -2.4847510, 0.0000000], - [-2.151857, -1.242375, 0.0000000], - [-2.151857, 1.2423750, 0.0000000], - ] - ), - is_bohr=False, - ) +def molecule(name: str): + name = name.lower() + + if name == "h2": + return Structure( + atomic_number=np.array([1, 1]), + position=np.array([[0.0, 0.0, 0.0], [1.4, 0.0, 0.0]]), + ) + + if name == "water": + r"""Single water molecule + Structure of single water molecule calculated with DFT using B3LYP + functional and 6-31+G** basis set """ + return Structure( + atomic_number=np.array([8, 1, 1]), + position=np.array( + [ + [0.0000, 0.0000, 0.1165], + [0.0000, 0.7694, -0.4661], + [0.0000, -0.7694, -0.4661], + ] + ), + is_bohr=False, + ) + + raise NotImplementedError(f"No structure registered for: {name}") diff --git a/test/test_experimental.py b/test/test_experimental.py index a22c58e..3dce7a5 100644 --- a/test/test_experimental.py +++ b/test/test_experimental.py @@ -20,12 +20,12 @@ ) from pyscf_ipu.experimental.mesh import electron_density, uniform_mesh from pyscf_ipu.experimental.primitive import Primitive -from pyscf_ipu.experimental.structure import Structure, to_pyscf, water +from pyscf_ipu.experimental.structure import to_pyscf, molecule @pytest.mark.parametrize("basis_name", ["sto-3g", "6-31g**"]) def test_to_pyscf(basis_name): - mol = water() + mol = molecule("water") basis = basisset(mol, basis_name) pyscf_mol = to_pyscf(mol, basis_name) assert basis.num_orbitals == pyscf_mol.nao @@ -36,7 +36,7 @@ def test_gto(basis_name): from pyscf.dft.numint import eval_rho # Atomic orbitals - structure = water() + structure = molecule("water") basis = basisset(structure, basis_name) mesh = uniform_mesh() actual = basis(mesh) @@ -72,11 +72,12 @@ def test_overlap(): @pytest.mark.parametrize("basis_name", ["sto-3g", "6-31+g", "6-31+g*"]) def test_water_overlap(basis_name): - basis = basisset(water(), basis_name) + basis = basisset(molecule("water"), basis_name) actual_overlap = overlap_basis(basis) # Note: PySCF doesn't appear to normalise d basis functions in cartesian basis - expect_overlap = to_pyscf(water(), basis_name=basis_name).intor("int1e_ovlp_cart") + scfmol = to_pyscf(molecule("water"), basis_name=basis_name) + expect_overlap = scfmol.intor("int1e_ovlp_cart") n = 1 / np.sqrt(np.diagonal(expect_overlap)) expect_overlap = n[:, None] * n[None, :] * expect_overlap assert_allclose(actual_overlap, expect_overlap, atol=1e-6) @@ -90,10 +91,7 @@ def test_kinetic(): # Reproduce the kinetic energy matrix for H2 using STO-3G basis set # See equation 3.230 of "Modern quantum chemistry: introduction to advanced # electronic structure theory."" by Szabo and Ostlund - h2 = Structure( - atomic_number=np.array([1, 1]), - position=np.array([[0.0, 0.0, 0.0], [1.4, 0.0, 0.0]]), - ) + h2 = molecule("h2") basis = basisset(h2, "sto-3g") actual = kinetic_basis(basis) expect = np.array([[0.7600, 0.2365], [0.2365, 0.7600]]) @@ -111,10 +109,10 @@ def test_kinetic(): ], ) def test_water_kinetic(basis_name): - basis = basisset(water(), basis_name) + basis = basisset(molecule("water"), basis_name) actual = kinetic_basis(basis) - expect = to_pyscf(water(), basis_name=basis_name).intor("int1e_kin_cart") + expect = to_pyscf(molecule("water"), basis_name=basis_name).intor("int1e_kin_cart") assert_allclose(actual, expect, atol=1e-4) @@ -126,10 +124,7 @@ def test_nuclear(): # Reproduce the nuclear attraction matrix for H2 using STO-3G basis set # See equation 3.231 and 3.232 of Szabo and Ostlund - h2 = Structure( - atomic_number=np.array([1, 1]), - position=np.array([[0.0, 0.0, 0.0], [1.4, 0.0, 0.0]]), - ) + h2 = molecule("h2") basis = basisset(h2, "sto-3g") actual = nuclear_basis(basis, h2.position, h2.atomic_number) expect = np.array( @@ -144,7 +139,7 @@ def test_nuclear(): def test_water_nuclear(): basis_name = "sto-3g" - h2o = water() + h2o = molecule("water") basis = basisset(h2o, basis_name) actual = nuclear_basis(basis, h2o.position, h2o.atomic_number).sum(axis=0) expect = to_pyscf(h2o, basis_name=basis_name).intor("int1e_nuc_cart") @@ -173,10 +168,7 @@ def test_eri(): assert_allclose(eri_primitives(a, b, c, d), 0.940316, atol=1e-5) # H2 molecule in sto-3g: See equation 3.235 of Szabo and Ostlund - h2 = Structure( - atomic_number=np.array([1, 1]), - position=np.array([[0.0, 0.0, 0.0], [1.4, 0.0, 0.0]]), - ) + h2 = molecule("h2") basis = basisset(h2, "sto-3g") indices = [(0, 0, 0, 0), (0, 0, 1, 1), (1, 0, 0, 0), (1, 0, 1, 0)] expected = [0.7746, 0.5697, 0.4441, 0.2970] @@ -188,10 +180,7 @@ def test_eri(): def test_eri_basis(): # H2 molecule in sto-3g: See equation 3.235 of Szabo and Ostlund - h2 = Structure( - atomic_number=np.array([1, 1]), - position=np.array([[0.0, 0.0, 0.0], [1.4, 0.0, 0.0]]), - ) + h2 = molecule("h2") basis = basisset(h2, "sto-3g") actual = eri_basis(basis) @@ -210,7 +199,7 @@ def test_eri_basis(): @pytest.mark.parametrize("sparse", [True, False]) def test_water_eri(sparse): basis_name = "sto-3g" - h2o = water() + h2o = molecule("water") basis = basisset(h2o, basis_name) actual = eri_basis_sparse(basis) if sparse else eri_basis(basis) aosym = "s8" if sparse else "s1" From becd9d598be92163a6bc32fdd519370cb83986ef Mon Sep 17 00:00:00 2001 From: Hatem Helal Date: Thu, 14 Sep 2023 15:38:13 +0000 Subject: [PATCH 06/13] skip water eri test on low mem machine --- test/test_experimental.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/test/test_experimental.py b/test/test_experimental.py index 3dce7a5..f70a5af 100644 --- a/test/test_experimental.py +++ b/test/test_experimental.py @@ -196,7 +196,16 @@ def test_eri_basis(): assert_allclose(actual, expect, atol=1e-4) +def is_mem_limited(): + # Check if we are running on a limited memory host (e.g. github action) + import psutil + + total_mem_gib = psutil.virtual_memory().total // 1024**3 + return total_mem_gib < 10 + + @pytest.mark.parametrize("sparse", [True, False]) +@pytest.mar.skipif(is_mem_limited(), reason="Not enough host memory!") def test_water_eri(sparse): basis_name = "sto-3g" h2o = molecule("water") From 300be646e986b1c4621da1357b79124f5b18db78 Mon Sep 17 00:00:00 2001 From: Hatem Helal Date: Thu, 14 Sep 2023 15:41:23 +0000 Subject: [PATCH 07/13] fix marker --- test/test_experimental.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_experimental.py b/test/test_experimental.py index f70a5af..ad62bd8 100644 --- a/test/test_experimental.py +++ b/test/test_experimental.py @@ -205,7 +205,7 @@ def is_mem_limited(): @pytest.mark.parametrize("sparse", [True, False]) -@pytest.mar.skipif(is_mem_limited(), reason="Not enough host memory!") +@pytest.mark.skipif(is_mem_limited(), reason="Not enough host memory!") def test_water_eri(sparse): basis_name = "sto-3g" h2o = molecule("water") From c87d33b3942436723ceb1e0bd07c104570da7e19 Mon Sep 17 00:00:00 2001 From: Hatem Helal Date: Fri, 15 Sep 2023 08:15:34 +0000 Subject: [PATCH 08/13] adding sympy --- requirements_core.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements_core.txt b/requirements_core.txt index 906c37f..42a6fb9 100644 --- a/requirements_core.txt +++ b/requirements_core.txt @@ -23,6 +23,7 @@ imageio[ffmpeg] py3Dmol basis-set-exchange periodictable +sympy # silence warnings about setuptools + numpy setuptools < 60.0 From 0f4625164b1f6b3cddc393f7ad77b40f706db8d4 Mon Sep 17 00:00:00 2001 From: Hatem Helal Date: Fri, 15 Sep 2023 08:21:38 +0000 Subject: [PATCH 09/13] Adding notebook on gto integrals --- notebooks/gto_integrals.ipynb | 400 ++++++++++++++++++++++++++++++++++ 1 file changed, 400 insertions(+) create mode 100644 notebooks/gto_integrals.ipynb diff --git a/notebooks/gto_integrals.ipynb b/notebooks/gto_integrals.ipynb new file mode 100644 index 0000000..38d7894 --- /dev/null +++ b/notebooks/gto_integrals.ipynb @@ -0,0 +1,400 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "*Copyright (c) 2023 Graphcore Ltd. All rights reserved.*\n", + "\n", + "\n", + "# Integrals over Gaussian Type Orbitals (GTO)" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import IPython\n", + "from sympy import *\n", + "\n", + "init_printing(use_unicode=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Primitive GTO\n", + "\n", + "The unnormalized Cartesian GTO (Gaussian-type Orbital) primitive is defined as follows:\n", + "\\begin{align*}\n", + "\\hat p(\\mathbf{r}; l,m,n, \\alpha) =\n", + " p(\\mathbf r; \\nu) = ~ & x^l y^m z^n e^{-\\alpha |\\mathbf{r}|^2} \\\\\n", + "& \\text{where} ~ \\mathbf{r} = (x, y, z) ~ \\text{and} ~ \\nu = (l, m, n, \\alpha)\n", + "\\end{align*}\n", + "\n", + "From which the normalized primitive is\n", + "\\begin{align*}\n", + "p(\\mathbf{r}; \\nu) = & \\frac 1 {Z_\\nu} ~ \\hat p(\\mathbf{r}; \\nu)\n", + "\\end{align*}\n", + "where the normalizing constant $Z_\\nu$ is defined so that\n", + "$$\n", + "\\int_{-\\infty}^\\infty p(\\mathbf r; \\nu)^2 \\mathrm{d}\\mathbf r = 1,\n", + "$$\n", + "that is:\n", + "\\begin{aligned}\n", + "Z_\\nu^2 = \\int \\hat p(\\mathbf{r}; \\nu)^2 d\\mathbf{r}\n", + "\\end{aligned}\n", + "\n", + "So, squaring all terms in the integrand:\n", + "\\begin{aligned}\n", + "Z_\\nu^2 & = \\iiint x^{2l} y^{2m} z^{2n} e^{-2\\alpha (x^2 + y^2 + z^2)} \\text{d}z \\text{d}y \\text{d}x\\\\\n", + " & = \\iiint x^{2l} y^{2m} z^{2n} e^{-2\\alpha x^2}e^{-2\\alpha y^2}e^{-2\\alpha z^2} \\text{d}z \\text{d}y \\text{d}x\\\\\n", + " & = \\iiint x^{2l} e^{-2\\alpha x^2} y^{2m} e^{-2\\alpha y^2} z^{2n} e^{-2\\alpha z^2} \\text{d}z \\text{d}y \\text{d}x\\\\\n", + " & = \\int x^{2l} e^{-2\\alpha x^2} \\int y^{2m} e^{-2\\alpha y^2} \\int z^{2n} e^{-2\\alpha z^2} \\text{d}z \\text{d}y \\text{d}x\\\\\n", + " & = \\int x^{2l} e^{-2\\alpha x^2} \\text{d}x \n", + " \\int y^{2m} e^{-2\\alpha y^2} \\text{d}y\n", + " \\int z^{2n} e^{-2\\alpha z^2} \\text{d}z\\\\\n", + " & = ZZ(l,\\alpha) ZZ(m,\\alpha) ZZ(n,\\alpha)\n", + "\\end{aligned}\n", + "where \n", + "$$\n", + "ZZ(k, a) = \\int_{-\\infty}^{\\infty} t^{2k} e^{- 2 a t^2} dt\n", + "$$\n" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "metadata": {}, + "outputs": [ + { + "data": { + "text/latex": [ + "$\\displaystyle ZZ(k,\\alpha) = \\int\\limits_{-\\infty}^{\\infty} t^{2 k} e^{- 2 \\alpha t^{2}}\\, dt = \\left(\\frac{1}{2 \\alpha}\\right)^{k + \\frac{1}{2}} \\Gamma\\left(k + \\frac{1}{2}\\right)$" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 42, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "a = Symbol(\"alpha\", positive=True, real=True)\n", + "k = Symbol(\"k\", integer=True, nonnegative=True)\n", + "t, x, y, z = symbols(\"t x y z\", real=True)\n", + "\n", + "ZZ = Integral(t**(2*k) * exp(-2 * a * t**2), (t, -oo, oo))\n", + "IPython.display.Math(r'ZZ(k,\\alpha) = ' + latex(ZZ) + ' = ' + latex(simplify(ZZ.doit())))" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "metadata": {}, + "outputs": [ + { + "data": { + "text/latex": [ + "$\\displaystyle Z^2(l,m,n,\\alpha) = \\left(\\int\\limits_{-\\infty}^{\\infty} t^{2 l} e^{- 2 \\alpha t^{2}}\\, dt\\right) \\left(\\int\\limits_{-\\infty}^{\\infty} t^{2 m} e^{- 2 \\alpha t^{2}}\\, dt\\right) \\int\\limits_{-\\infty}^{\\infty} t^{2 n} e^{- 2 \\alpha t^{2}}\\, dt \\quad = \\quad \\Large \\left(\\frac{1}{2 \\alpha}\\right)^{L + \\frac{3}{2}} \\Gamma\\left(l + \\frac{1}{2}\\right) \\Gamma\\left(m + \\frac{1}{2}\\right) \\Gamma\\left(n + \\frac{1}{2}\\right)$" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 43, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "l,m,n,L = symbols(\"l m n L\", integer=True, nonnegative=True)\n", + "Z2 = (ZZ.subs(k, l) * ZZ.subs(k, m) * ZZ.subs(k, n))\n", + "Z2_expanded = simplify(Z2.doit()).subs(l+m+n, L)\n", + "IPython.display.Math(r'Z^2(l,m,n,\\alpha) = ' + latex(Z2) + ' \\quad = \\quad \\Large ' + latex(Z2_expanded))" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Help on function _lambdifygenerated:\n", + "\n", + "_lambdifygenerated(l, m, n, alpha)\n", + " Created with lambdify. Signature:\n", + " \n", + " func(l, m, n, alpha)\n", + " \n", + " Expression:\n", + " \n", + " (1/(2*alpha))**(L + 3/2)*gamma(l + 1/2)*gamma(m + 1/2)*gamma(n + 1/2)\n", + " \n", + " Source code:\n", + " \n", + " def _lambdifygenerated(l, m, n, alpha):\n", + " return ((1/2)/alpha)**(L + 3/2)*gamma(l + 1/2)*gamma(m + 1/2)*gamma(n + 1/2)\n", + " \n", + " \n", + " Imported modules:\n", + "\n" + ] + } + ], + "source": [ + "f = lambdify((l,m,n, a), Z2_expanded, modules=\"scipy\")\n", + "help(f)" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "metadata": {}, + "outputs": [ + { + "data": { + "text/latex": [ + "$\\displaystyle \\Large{z_{000}= \\frac{\\sqrt[4]{2} \\pi^{\\frac{3}{4}}}{2 \\alpha^{\\frac{3}{4}}}}$" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 45, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "z000 = sqrt(Z2).subs(((l,0), (m,0), (n,0))).doit()\n", + "IPython.display.Math(r'\\Large{' + latex(Symbol('z000')) + '= ' + latex(z000) + '}')" + ] + }, + { + "cell_type": "code", + "execution_count": 55, + "metadata": {}, + "outputs": [ + { + "data": { + "text/latex": [ + "\\begin{align*}~\\\\z_{000}= z_{000}\\\\z_{100}= \\frac{z_{000}}{2 \\sqrt{\\alpha}}\\\\z_{010}= \\frac{z_{000}}{2 \\sqrt{\\alpha}}\\\\z_{001}= \\frac{z_{000}}{2 \\sqrt{\\alpha}}\\\\z_{110}= \\frac{z_{000}}{4 \\alpha}\\\\z_{101}= \\frac{z_{000}}{4 \\alpha}\\\\z_{011}= \\frac{z_{000}}{4 \\alpha}\\\\z_{111}= \\frac{z_{000}}{8 \\alpha^{\\frac{3}{2}}}\\\\z_{200}= \\frac{\\sqrt{3} z_{000}}{4 \\alpha}\\\\z_{220}= \\frac{3 z_{000}}{16 \\alpha^{2}}\\end{align*}" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 55, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "out= r'\\begin{align*}~'\n", + "for vl,vm,vn in ((0,0,0),(1,0,0),(0,1,0),(0,0,1),(1,1,0),(1,0,1),(0,1,1),(1,1,1),(2,0,0),(2,2,0)):\n", + " zllm = sqrt(Z2).subs(((l,vl), (m,vm), (n,vn))).doit() / z000 * Symbol('z_000')\n", + " out += r'\\\\' + latex(Symbol(f'z{vl}{vm}{vn}')) + '= ' + latex(zllm)\n", + "out += r'\\end{align*}'\n", + "IPython.display.Latex(out)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# From factorial" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1.0 1.0\n", + "1.0 1.0\n", + "3.0000000000000004 3.0\n", + "15.000000000000004 15.0\n", + "135135.00000000003 135135.0\n", + "316234143225.00006 316234143225.0\n" + ] + } + ], + "source": [ + "import scipy\n", + "p = Symbol('p', integer=True)\n", + "\n", + "# n!! = special.gamma(n/2+1)*2**((m+1)/2)/sqrt(pi) n odd\n", + "# = 2**(n/2) * (n/2)! n even\n", + "\n", + "def fac2(n):\n", + " assert n % 2 == 1\n", + " return gamma(n/2+1)*2**((n+1)/2)/sqrt(pi) # n odd\n", + "\n", + "\n", + "for n in (-1,1,3,5,13,23):\n", + " v1 = scipy.special.factorial2(n)\n", + " v2 = float(fac2(n).evalf())\n", + " print(v1, v2)\n", + " assert np.isclose(v1, v2)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 48, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAcAAAAAVCAYAAADRoT5bAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAABJ0AAASdAHeZh94AAALQ0lEQVR4nO2df7BVVRXHP09UIDBTHLWUHxIpKvYuvGxsGlEDLNLSZzo4YyQ2lUORv8IxmWyx1MxGDLQpqcERnJE0gn6pUDAqaigGSagpmQiYP1PMQFMR6Y+1D5537jn3nnvPvvfcK+c7wxzePvvstc/6sfdea699bseOHTvYVaGqS4BHROTivPvSLlDV0cBUoAv4CHCOiMyt8sx0QCLFL4rIgY3oY4ECBQqkwW55dyBnlICHq1VS1bluEG8qmkW3Rjr9gUeB84H/1UBmHfDh0L+jauljgQIFCvjG7nGFqjoPGA8cIiKvN7dLfqCqXcAq4OsiMifm/oHAAcAajzQPBp4BzgTOBY4BngQmYBPHj4FPYBPI6SKyqQG0TwUmA8cCLwBfE5G7fdERkTuBOx3NuTU8+o6IvOCrH2FU0ldVvRDj+1kiMr8R9FsF1XQ+Q7ttPx7kiUIu+aIS/8smQFU9GpgITA0zVVVPB47DvKZOYC/gFhH5cooOHAA8B/wMmA50AydhXsBBwNvAI8BNwE0i8m6tLxmFiKxW1d8CV6jqrSKyNVKlhHkw67LSCqHTXSdjIb/NwAJgLvAGMA3YAvwGuAi4oAG0LwKuAKYA12KD/0iPdOrFUFV9DngLWAlME5H10UqqmiYmv4+I/MfVj9XXELrc9a919boO1GsrWe0khc43gr9tg7zGsEIu+c4flfgf5wH+APgvcEOk/Huu41uBfwHDq3U8hFOwcOsi4AzX9vPA3cAmzBM7DZgDjFfVM0TEx+bkD7HB9jzgqsi9Erb/t90DnXCbrwETRORFAFVdinmEw0XkFVe2HAsD+kQJk9uEwNNS1V9jPMgbK4FJwBPA/pgurVDVIwOexEArtPdm6P9J+hpgFKaz/6ilwxlRr634sJNKOh+GL/62E/Icw3Z1ueQ9f8Tyv8cEqKqHAmOBOSIS3d+50HX8n9hMXktYrRt4BbgX6AC+CNwRnqlVdRrwEPAl9zILa2g/FiLykKo+AZyrqldHVgYlEsKfri/TQkW9gR2qOjVUNl5E7os82gncHkx+DoOARZGBfhDwF490A9p3RMKMwzB5+aRTM0RkcYT+g8B64GzMQ417Znq1dqvoK6raDzgMWFFLVEFVJ2GryRNE5J60z4VQr61ktpMqOh+uN71aZ6rxt9nIUS6QUTbvB7lk5H+u80cS/6Me4FcdgduiPQjvI6lWWqT0hKp+EPgMMN95W3fF1RORF1R1NrayOT76Aq6tPYBvYd7EYRhTFgCXuHd5GlgmImeFHrsVc5vHAX8MlZeA6xK6PRv4VejvHwHPAteHyp6Nea4UqQMWfvx+pKwT+IVHugHtn8TQXuOZTmaIyFZVfQz4WMamEvXVoYStHJsW/oT6bMWnnZCs87WiIn9V9QTXz2uB+ZiejwZ6ufLJrr9HYB7AWGyxdR/wTZ974GnQAmNYs+QyDvgTcDVwC7bYHQP0A9YCF4rIygz0a0YL8B5i+B+dAMcC24EHU/ewOk4G9sT2vaphm7u+E72hqvsCS4CjgduxFzgZ20d7FngX2JfydPs/u+vOl1bVD2CDb2wGqIhsxvbvAtpbgM0iUuZNher0Az4ablNVBwADI2UDgQFxtOuhm0TbYSQWNvBCxxdUtQ8WAsmanFNNX0e56+qMdJoBL3biUKbzdSItfw8F7gcWAzcCn8OSsXqr6g3AL7GBax42OJ0E3Oz+3w7wJZtmySXY8z8UizQtxXh/JPB54PeqOkxEtmToQ7PQULvYOQG6QbQEPO55U7UbeB1bkSRCVXcHvuL+XBJT5TZs8jtfRK53z1yDudXjgSOAuTGDeBBqHB0q+7i7rk35DmkQtLkmVFbCkj4ei5RtJSY06ZO2m3wPxmOWq2u3PxZaBfOuBqlqCZtAN7k6U4ApIjLc/T0D+AMWr98fuAxbjc7L0I80+tr0BJgM8GUnEK/zNSElf4MJ8JPAMSKy1j17OSbrE7HBeJyIPODu7Ynp/mhV7SMib5Y323LwJZtmy+VY4NMislP/VXUhFiIsYZ54q6OhdhE+B3gQFrp4vq5uxneqD7YaXJxC0a8GRgB3ikiP1ZGqjsVWPfcRCvOJyMvABsxF3ge4PNqoiLyGbR4PChWXgCdF5I3a3qgiOl2bYaUcCTwqIu9E6v3NR6ZrDO1wdtlIbEX0d490wI5xPOz+9cU27R+mJ+/3w0LUAQ7GvIB1mEf6FjZgbszQjzT6OgrL9H08A52Gw5edBEjQ+VqRlr8Ak4LJz9HfgtllL+DiYPJz997G9KADWwS1NHzKpolyCTzAc8KTn0NgC30y9KEpaIZdhEOgA9z11bp7XI5x2Pm3iu6rqp4HfAfLEpwYUyUomxWT3RMw5uci8kwCic1YphAAIjIb2wdLBRGZlKJOWZsiMgOYESm7ErjSF90KtJdhoQNvdFzde7DBq1Kd6VisPfj7zLTt14CK+uqM53BgVaVMX1XdAAxOuH13zH7FvFr4lRK+7CSMHjpfB6rxtx8WYlsvInEr7sGuD3H7VIOBLRUygFtFLuBfNo2WS39se2cj7rxuBEPd9alKRFqE/w23i/AEGGQT+VwZnIad0bgjqYILl12HeSpj3P5UFMdh3kySa/sGlVOL+1LbV0sKtD6q6Wsnpt/Vwp+zgA9FykpY6vU8zJMJY0267tUEX3YSRladT8Pf3YBl0RuqOgSLyCwSkW2Re0EI/f4q9GeRv1zAv2waLZcStkBdmnCUbBR2VOvpKnRmkT//G24X4QnwJXcdgAeoai/gC8BdzvWMq3MBMBP7MsoYEXkppk5fzGV9KhqyVNWhWDLFisjRg3Cd3TBBVhN4gfZCNX1NlQAjIrOiZS7d+xRsT/me+rqXDr7sJFLfh85n4W9XhXsjsQG64sIkb7k4el5lk7dcVHUvzGtfnjA57kTe/G+WXYT3AJ8H/k3PvZssGI0JKdZ9VdVLsM6vwc6VJHW+L2YwcXtmM7G06qSsH7D36aBxK8QC+aCavrZLAowvOwnDh85X428w0K6KuddV4V6wP9XqcgH/smmGXAL+JvG+g/bIim6KXeycAN2K4F5gP1UdVvZo7ejGJq3fRW+o6mXYpuVqbOZ+uUI7r2JZk8NUNch2RFUnYwciodxVD+MYd/X2PcwC+SOFvo6iPAO3FeHLTsLIrPMp+fs2tvqOopIHGEyc7TAB+pZNM+USl+HeLotCaJJdRM8BLsRO0n+WSJq+qp6Kne0BCH7G5lP63geRXxaRqa5uh6v7QDQ0qapnYxmD27GszvNiNlQ3iPuJHRHZ4WhMAZap6gJHvxtjzt7A8e4Q5I0iEv3CyomOVhkjC7Q9YvXVpdqPwLJttyU82zCktRWfdhKBL51P4m9v7NjRWpfVGUUXsDEhySW3zNy8xrAQWkEukIMH2AK8hxj+x02AL2HnKX4auVfCPl0VxlDeyyraiP1OHFiq/EBsIzWKQ9y1F8kfg16OfUA6wMXYan4C8A3MK5wJfBdz62/Gfn2hx+l/Vd0bY+TtFTJEC7QvkvT1KGAP8lvplkhnK77txLfOJ/F3BMbfuH2mwVjoannMvd5YZu7qSpm5DUSJnMawJspld+LDn2ATYLO/ixugRH7zRyL/O6I/iKuql2IZlaNEJPZLKdWgqlcBlwJDRSS35BNV/Tb2ea9jRaRa1lmBNoQPfc0LjbAT3zrfzvzNAt+yKeSSHs20i7gfxJ2JfcWh7FB5DejGwk95Tn59MSYuLCa/9zV86Gte8GonDdL5duZvFniTTSGXmtE0uyibAN2J+4nAKnfYtWaIyOEiUqrnWY8Ygn1wemqVegXaGD70NS80wE6G4Fnn25m/WeBZNkMo5JIazbSLshBogQIFChQosCvg/4NessAgH0GVAAAAAElFTkSuQmCC", + "text/latex": [ + "$\\displaystyle \\left(\\frac{1}{2 \\alpha}\\right)^{l + m + n + 1.5} \\Gamma\\left(l + \\frac{1}{2}\\right) \\Gamma\\left(m + \\frac{1}{2}\\right) \\Gamma\\left(n + \\frac{1}{2}\\right)$" + ], + "text/plain": [ + " l + m + n + 1.5 \n", + "⎛ 1 ⎞ \n", + "⎜───⎟ ⋅Γ(l + 1/2)⋅Γ(m + 1/2)⋅Γ(n + 1/2)\n", + "⎝2⋅α⎠ " + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYsAAAAVCAYAAAC3+8IMAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAABJ0AAASdAHeZh94AAAKdUlEQVR4nO2de7BXVRXHPxdQYITwUaMzTvGIKA3ieq82mIqYoJEPuKZjk5FUU46PSCYcw7LFMjXKSa42hjU4gpPkYzAqEUseIgIpUISa5PgAjUjikYGPQKA/1j73nnvuOed3zu93fr9zf3C+M3fOvfu5zlpr77X3Xmuf23DgwAEKFChQoECBOPTIm4A8oaqLgc0i8pUqtX81cAUwwCW9ANwsIgtCyt4NvAe8CVwEfBz4H/AnYKqIPF8NGgsUKFAgCbrlTUDOaALWpqmgqrNVdVrC4v8Arnf9nAwsAear6qcCbTYAFwLzgVHAz4HPAJ8F3gcWqerRaegsUKBAgSwRurNQ1TnAWGCgiLxdW5Kygao2A2uAb4jIrJD8jwJHktJYpIGI/DaQ9D1VvRI4FVjvSz8F6Ak8LSLnBuicALwFnAb8vlKaomSrqpOB24HLRGRupf10dZTSjwrarfuxkycKueSLOP53MhaqegowAZgSmEwuBs4EGoHhQF/gfhH5cgICjgX+ia2YpwEtwHnAMOB4YA/wHHAvcK+I7E/7kkGIyFpVnQ/8UFUfEJHdgSLNwH5gXaV9JYGqdgcuAfoAKwPZ44EFIvJ+SNW+2A5wZ6C9JM6mo0TkP746obJ1aHbPPydoNzOUq1eV6lQC/UjN4xL8rSsUcskPec61cfwP21ncAvwXmBlI/74jfDd2vPKJUoT7MA6b8B7BJsyZwBZgKfA6cCx2Tj8LGKuql4hIFp73HwHPAJOAWwN5zcBLYcqYJVR1GLAK6IXxrkVEngsUGwfcGNHEHZhBWxXVRUz37wX+jpIt2FHZbuClmPaqgXL1KguditMPP5LyOI6/9YZCLvkh77k2lP8djIWqDgFGA7NE5N1AA5Md4S9jVm9pihdoAbYDTwHe+fwCv1VT1RuAZ4EvuJeZl6L9UIjIs6q6AbhCVacHrGgif4Wj6wZfUk/ggKpO8aWNFZHlEU38HVsh9AMuBuao6ijPYa2qg4FBwB9C+r4dOB04XUT2RbzjtFLv4NqKlK2qHoE51Fem3dWp6kRslXKWiDyZpq5DuXpVsU6V0A9/uWmliCkxdmqOQi5tfdVcLjnyHqrI/+DO4muugweDFIhIG8GqcQa9I1T1A5ijdq6b8JaElRORf7mIoFswJ28nY6GqhwFXAxOxyW078DDmRO4BvAYsEpHLfNUewLZjY+g4ITcBNyd4hbuBh3x//xjYDNzpS9scVVlE9mBCB1jrtsSTga+7tPHA4uD2WFVnAF/EFO7VBHSWQqRsMWPWjRofQUF5epWlThGtH2kRx19U9SxH50+BucAPgJFAd5d+paP3RGxlORpbmCwHrhKR1yugLTUOIbmMAf4ITAfuxxaGZwNHYH7FySLyTAX9p0ZXmGsJ4X/QWIwG9mHhmlnhfOBw4DcJyu51z05n9y4a6HHMGfwo9gLnA9dik/V+4GhAAlVXuGfbS6vqQFe25OQoIjuAHT46dgE7ROTl6Fqx6IZNAh7GAXP8BVT1DuBSzFBsKLOfIOJk2+SeVXP2Z4xMdMqhk36UiVJjx+PxEOBpYCFwD/A5bMHQU1VnAr/GBvkcbCCfB9znfu/qqEe5nOSeQ4DVwBMY7z8JfB74naoOFpFdFdBQK1SV/23Gwh1FNAIvZuwEagHexqx3JFS1B+Ddd3g8pMiDmKH4tojc6erchm3XxgInArNDJvHV7jnSl+Y5c/eq6lBf+j4ReTH+dZJDVacDC4A3MCfVl2ifAFDVDwEjsOMpr85dmDNuPLBTVY9zWbvL9a8kkG0uzu0KkJVOQbh+pELCseMZi08DI0Rkvat7E3aWfA42cY0RkVUu73BsVzpSVXuJSNAH1dVQz3I5AzhNRNrGgKrOw45pGrEdXldHVfnvv2dxPLYl3lIWmeFE9cJWTgsTKPp0YCjwmIh0WEmo6mhshbAc+JmXLiLbgI3Y1uso4KZgoyLyFubs+ogv2ZscV2CRAd7PQ2SL44BfYX6LxZixGysiC13+BcBqEXnTV+cqzLAsxmTh/fh9JGlRSrZNwLtAZoayWshKpzxE6EdaJBk73qQ00TMUrv9dmA53B67zDIXL24PpTgN2LNJlUcdy8XYWX/UbCgdvPPSqgIaaoBb89x9DHeOeHUI0K8QYLFQ0dlukqpOA7wAbsFV1EF5aa4jn3mPML0TkjYgudmBRAACIyFRgajzp4RCRiRmWHYddxPPXaUhNVGlEytYp2QnAmignuq/sRqB/RPbSkPPVOWn4lRBZ6ZQfHfSjDMSOHbfCHQK8KiJhK7n+joawc/X+wC4R2R7VeSGXSJSSSx/gY8Am4LGQIoPc85WoDroI76EG/PcbCy9SIEsrehEW19vp8xYeVPUaLDz0b8DZzkcQxJnYGVvUlukd4kPsetP+fl0JK7Az6mojTrbDMT1IcgTVil1k9KORdr/LxkDeumTkpUJWOuVHpfpRauwMx3bxi4IZqjoA2xU/IiJ7A3l9gMGYjyMOrRRyCUMpuTRiu7YnIkL1m7ALsa/F9NFK/ryHGvDfbyy2uucxZAB3Ce0CYInb0oSVuRaYATyPEb81pExvbCv0ioi8E8gbhMUgrwwc5fjLdMOEGSfwXCAiP6lRV3GyTezcFpHWYJoLExyH+YueLI+8ZMhKpwLls9CPUmMnjsfNMXknYZNZrCEv5BKJsuWiqn2x3eCyuDtfefPe9VcT/vt9FluAf2MhqVlgJCak0G2Rql6PEb8Oi/qJIr43NmDCYq1nYJFFUR59sPdpoEY3tbso4mRbT87trHTKjyz0o9TY8SalNSF5zTF53nl6V5dNvcrF428U7xuojwjBmvC/zVg46/kU8EF3UaxStGATfPD7SKjqjZiTZS1m5bbFtLMTu8k4WH0f4FP7xtKF7s8jY+qPcM80F1sOKpSQbRP2ddsXak5YemSlU35UrB8Jxk4TdkQQ9uXguJ2FZ2S6urGod7msD8mrp0VUTfgfvGcxD7vVdy7tF8m8Tsdj4ZxgUT4Ap6rqbPf7NhGZ4so2uLKrgsdDqno5FrW0D4tumhTiBNooIrPBBO76uAb7+urDrv8WjDn9gFHuksk9IrI60NY5rq9OjDzE0Em2LjRzKPDX4Hl5rZBUr7LUqQCy0o/QsaOqPbGw7vUuuimIZmBThAM7tyi1Qi6xRryqyGuuDaAT/8OMxVYsBveuQF4jcHkgbRDtEQObaA/vPBn4MOb8CWKge3bHLtSFYRkw2/f3ddjq91Lgm9huYwbwXWy7eB/2fyM63ERU1X4YIx+NiZQ6VBAm22HAYeS7emokmV5lrVNZ60fU2BmK8TjsXLw/dnywLCSvJxaltrZUlFqV0MjBL5cehB9BgRmLPL6VBvnOtZH8bwj+pzxVnYpFFjWJyF8iOoiFqt6KhaYOEpHcHMuq+i3ssxxniEipiJKDHlnINi9UQ6ey1o965m+5KOSSL2rJ/7B/fjQDu1Ha6YJbCrRgRxt5GoreGBPnFYaiDVnINi9kqlNV0o965m+5KOSSL2rG/07Gwt3+mwCscZeJUkNEThCRxnLqZogBwC+p7ObzQYUsZJsXqqBTA8hYP+qZv+WikEu+qCX//w8lQz8AqrJJcgAAAABJRU5ErkJggg==", + "text/latex": [ + "$\\displaystyle \\left(\\frac{1}{2 \\alpha}\\right)^{L + \\frac{3}{2}} \\Gamma\\left(l + \\frac{1}{2}\\right) \\Gamma\\left(m + \\frac{1}{2}\\right) \\Gamma\\left(n + \\frac{1}{2}\\right)$" + ], + "text/plain": [ + " L + 3/2 \n", + "⎛ 1 ⎞ \n", + "⎜───⎟ ⋅Γ(l + 1/2)⋅Γ(m + 1/2)⋅Γ(n + 1/2)\n", + "⎝2⋅α⎠ " + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "l,m,n = symbols(\"l m n\", integer=True, nonnegative=True)\n", + "alpha = Symbol(\"alpha\", positive=True, real=True)\n", + "\n", + "def current_squared(l,m,n):\n", + " L = l+m+n\n", + " N = pi**1.5 / (2 * alpha) ** (L + 1.5)\n", + " N *= fac2(2*l-1) * fac2(2*m-1) * fac2(2*n-1) / 2**L\n", + " return N\n", + "\n", + "def gammabased(l,m,n):\n", + " def fac2pm1(p):\n", + " return fac2(2*p-1)\n", + " \n", + " L = l+m+n\n", + " N = pi**1.5 / (2 * alpha) ** (L + 1.5)\n", + " N *= fac2pm1(l) * fac2pm1(m) * fac2pm1(n) / 2**L\n", + " return N\n", + "\n", + "\n", + "display(current_squared(l,m,n).simplify(), Z2_expanded)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Next steps" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "And now, evaluate over the range $(0,1)$ rather than $(-\\infty,\\infty)$:" + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAIcAAAAoCAYAAADUrekxAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAABJ0AAASdAHeZh94AAAHqElEQVR4nO2cf5BVZRnHPytY/lgaxyGgGKoxlYGx2Ni0GSGECjRTY8gZwwkLpEYslBxIMfXrV1HBQpdxdEomJYlqMqwBlWLUSItSskC3JmKoCC0IxDJHS4Xtj+ddOHvZu/fu7rl7uXK/M2fec877vud57nuf8/x8721oa2sjL9geB8wFmoF3AtMlLcuNQB19iiNyfl4j0ApcDrya87Pr6GM05Kk5srD9MvCluuaoXeStOep4E6EuHHUURdnCYbvBdr9KMlPHoYX+pQbYPgoQcAlwpO27gPmS9laaucMRtocBy4FBwBvAjZLurwYv5WiOpcBVwHHAscA84MoK8nS44w1gjqSRwCSgxfax1WCky2jF9iBgB9AAXJbOfwA8J2lYJ+MbgRPT5XpgIbAK2CPpb/myfhDthUCzpImVpNPXsL0JOEfS9r6mXcqsNBGCAbAC+DfwMLDL9jGSXikY/0HgZ5lrp+PbwOd6yqTt+cAUYDjwP+DXhGlrLeB1Y09pHIqw3Qz0q4ZgQGnhGJja/0rak84/UWywpHUcEKY8MR64C9iQnn8D8IjtkRm+moD7KkC7KrB9PPF5Pl8tHkoJx4DUvlRpRrqCpDOz17anEVpsDLDa9hBgMElzJBv9LeBkYIqkv/YVr+ltvxU4HXgemAEMAS6XNMb2+YQWPlnStjRnCXAOcLqknbbfCvwYWChpfV/xXohSDmljal+uNCPdxACC9xfTdRORrt9sezjwFOHYjeljwTgVeIIwre8nzJ+BrwLXpmErgWeBa9KcucBU4KwkGA3AMuAxSctz4muY7TbbD3RnXrma4z89Y6tiWEJoiV+l6yZiwScDdxPhX0sV+FoMrJa0AMD2d4HVwOOSHgOQ1Gb7auAh21uBq4GPStqSnjEGuAB4xvbkdG+apGd7wVdzan/bnUmlhOOQ0xy2bwPGAmMzuZYm4CTgHuA8ST/PidYC4q3vChMkrUum7cPAhEzfa4SGuzY7QdJa2xuABcC5kjZk+n5B/pnr0al9ujuTakpz2L4d+DTxhfw509UEPABcCByfI8kW4DslxrSH6CNS+5tM33Bgc/rC98P2R4BRhHO9s/dslkTtaA7ba4GJwPmSVmbuNwD3Ap8FFkm6KtO3hFC3EyT9MXP/GEJrTCNs/XLb4yQVXQjbnwJmAqcCbwO2JbqLsplfSbuB3WV+rOOANmBvojGA0Do7CmiPAn4EzCYiv1uADg53T2C7P/BF4GLCEd9JRHi3Eprj75J2prFlrX8p9VUpzTEP2AfcWFCv+Xpi7O4CwbgTmE5ohhdtD0lHI+H4tQGtklYAtxMRzNBCorb72f4e8EMiWXc/sYD7gJsIs9RTbCQ0wfzkFK8A/gG81/ZJif67gTXAYkn3EGWJibbH94Iutt9C5J9aCOG8A3gUuJ7wwYbQ0aSUtf5ViVYkbSLqByOIN57kpF1BZGBnFUy5lBDUR4kFbz/mEiZli6T2zUXXAb8EViWtksUSwiwtBEZImiVpDnAKkdG9yPbIHn6mvxCaYhawiXihPkZsflqf8hY/IRzWG9KcVkJAb+kJzQzuJDTBdcBoSfMkzSA00sw0Zr8mLXf9S6XPNxBZz5sllXLMuoVUYPoToXYXE9L+U8KhfC1PWoneh4joZpWkyZ30fwH4JjBD0r15068UbJ8GPAk8KOncTvr/QAjBeZJWZ+6XXP+qRSuStttuIYp6dxBv7pRKCEbCbELtv2L7+k76T0ltre1xmZ3am4r0v5DaDj5YOetf7WhlV+b84k5qNXliUmqnlhi3rYI8VAKTCAF4skj/CcA/JT3fSV+X6181zWH7QsIB2kFKL3Owr5EXraOAtxPJqDMqQaMaSJ9rEPA7SQf5B7ZHE78CWNNJX8n1L9chzVVz2D6bSBG3EtHGZmBm8vIrgfZi4MAuR9Ue9qZjUJH+dj+xg0kpd/2LCofto4H2MCc3zWF7LBFKPgecKWkXUWfoDyzKi04WKZJ5Bhhpe0oxvmptG6Sk14EtwFDbHZxR21cS2xwgE8Z2Z/2LRitpo0979m5MHtVB203AOqJINlbS1kxfe2Q0TtITvaXVCe1JwEPEIjxCCMsRwFAig3ikpHflTbfSsH0RsV/mdeD7hJkYD7yP8CmGAe+RtK2769+VWWnMnPdac9g+kYjz2wiJ3VowZH5qv9ZbWp1B0lqijL6SWLjLiBh/BCEs0ytBt9KQdB8wh9AEU4kk1naigNcGvJAEo9vr35XmGMWBnVUnpCRPHYcRuopWBmTOu3RIbV9KpGTfAfye2CCbu2moo2+x36zYPs32w+mnBxC2CmLTzJ6DZh6YdwGRlr4Z+ACRTFlju+bsdx0dkdUcu4GPA6/aXgp8Jt3fKGlfF8+4AlgmaWm6nm37LCJmnl98Wh2HOvZrjrQ/4kHgaCIuPjt1LSw2OVUDm4G1BV3tzl8dNYzCaGUaUWp+iagsfjJb7+8EA4lcSOGGlZ1E1q2OGkYHh1TSvzhgTuo4zNHbCuRuIn07uOD+YAp2QNVRe+iVcKTy7tPERpMsJhJRSx01jJK/si8DtxH7Np8idmBdQlQCv5HDs+uoInL526eUBPsKkQRrBb4s6fFeP7iOqqJi/wlWR+2j1rbE1dGH+D9jv+5kjCDNVwAAAABJRU5ErkJggg==", + "text/latex": [ + "$\\displaystyle \\int\\limits_{0}^{1} x^{2 k} e^{- \\alpha x^{2}}\\, dx$" + ], + "text/plain": [ + "1 \n", + "⌠ \n", + "⎮ 2 \n", + "⎮ 2⋅k -α⋅x \n", + "⎮ x ⋅ℯ dx\n", + "⌡ \n", + "0 " + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAANMAAAAXCAYAAACRfnp7AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAABJ0AAASdAHeZh94AAAIpUlEQVR4nO2be7BXVRXHPzx8UJqSlUyWDQyBMiFXHiGJmCljTgRcErXgKtUUk5KPBAkrv61KYypeOiQq5sUIH4X5ArF8IIKWyKORwcuAgmiZyGMsfBVw+2PtH57OPb/XOb97L1x/35k75/722Xutvc/Za6+1vnufdo2NjbQkzGwJ8JykSc2oYwgwEegHfBz4uqT6PHXnAO8ArwGjgJ7Au8BfgCmS1jVXP6toW2jfCjprgDXNrOMIYB1wGfB2vkpm1g4YDtwLfB74NfA54AvAHuARM/twM/e1ijaCji2pzMy6AMcCa8PvDwK3Aj2AUZK2VEKPpMXA4qCjvkDVAcBhwHJJZ8f6Wge8AZwKPJC1T2Y2DzgH6Crpzdi9K4DpwBhJC7LqqqIpzKwf8CzwLUlzm0NHixoT7pXeBjaYWU/gHtxLnSoprwdpRowEFknak3DvSNxz78qqxMwGAHXAxLghBfQL12ez6soKMzsXOB1/V33w5/A7SWNLaHss8A/cw/8YqAW+BPQGjgP+AzwH3AbcJmlf5UeQDEmrzOxe4Kdmdqek3fnqph1HS4d5NaETI4GngFskjW0lQwIYgYd4SZiFe9CnK6DnWuBfwI157vcN9zdWQFdW/BCYgL+rv5fZdgQ+p+4BRgO3AAOBvwIzgYXAZ4C5wN0hzG5J/BzoAlxapF6qcWT2TGb2M+AHRaqdIWkp/oI+DfwGGC7piQrJLRtm1h3oBjyccG86MBgYLGlvGvkRWT2As4C5SYtGCHV7Ak9KyswGmdk4fMVM+2yuAF4BNuEe6vEy2tYCO4BlQC4fXRT1QGZ2NfAM8BWc8FmYoo+pIOkZM2sAxpvZ1AKeMdU4mhiTmR0CXAKMw1/yDuD3wGTc+DYDj0gaE5rMBOYXGcfWcK3Brf1rQLHEvhy5aTASeDQhf5kBXIBPxhdj92qBu4HVwHmSXooLNTMDrgHGS7oZ+Ab+Qu7K048afBVcFZPTGajHX+QsYJKk/5Y1whSQtN94fCilwcw+hBM3C8IC9Fge+f8MDOq1OOmT2phSzFWAO/HQbSjJC2nqcfyfMQXmagmemD8YlA0DLsdd/j7cCBQRuh3YXsLAP4B7pTp8tfutmQ2RtDpPZ0uSmwEjgHmxPs4CzscNqSGhzTbgzziRcBX+IqPtu4XylXgIAO6V9uJUexL6hut+YzKzgbjxdQZGS/pDyaNqPQwDDgX+WELd3KKQlKuWhDRzNWBFuCYaExnGEfdMd4XOXSbp+tDpX+Ju/xygF1AvaVMJiuI4CWgE1klaaWYnAA+Y2WcllRubF4SZHQF0Dz/bA8ebWQ2wU9JWM/socApwbqTNbNzQRwK7AvMIsDuXrEpaEbzTDqB/gurr8RdxsaR9IYSrAZ7PQzzAe+TDqtCP7wFTgQZgqKQDIY8qBbXAm8CfClUys47AheHnkgz60s7VleE6JI/c1ONoH7l5Fr6KPgnckCsPHmIL7vo6Az8ppKQAaoCNkbzhGnyVuD94rUqiP84SrgE6ARb+z/X9y8BKSa9F2lyMM1ePAq9G/iZGBUt6F1gPnBAtN7PhOONzs6QcK3cc0CHIyYe+wL+B7WZ2HzANWAAMPFgMycwOB74IPCTpnSLVp+LJ+2JJSZ6hFH2p56qkN/BN+uMrPY6oZ6oL15kJiXBO8E2SXi6iJBGS5gBzIr8bgfPSyCpB11I8T8mHJiyepHKYpQZggJl9QtIrZtYJz/G2A1dH6h0Tron0enh5J+Lh42rgY5SxD2JmW4BP5bn9eELOM0/SuFJkl4mh+EZ5wdDIzC4FrsSfX12hukWQda7uxPc748g0jqgxnY7HgPlc71vAdYWUHERYAdyRoX0un+qFhxVTgK7ANyVFDSfnhQ/PI6cP/g4Ow489zS9zQ3EmcHSsrIb38sEtsXtry5BdDkbhey+L8lUwswk4kbIeOFPSzgz6ss7VTiSfjMk0jo6hQifc7b0g6a1Y4254SPNULCw6aCHpFxlF7DcmM3sBJx2exinpKLaF6zEkI0c+XInnb2PNbJWkmaV0IqleoMZH4PnC0lLkZIGZdcDD5sdCCJVU53JgBn7E60xJ25Lqlagv01w1s/b4ArS50uPI5Uyd8LAoiXefga+cqZmXNoioZ5qFL0qXJIQcrwKv47RtEqInHy7ASYhpZjaqst1tVgzBF4vE0MjMJuNzaC3OkqY2pICsc7VnaL82Vp55HDlj2gXsBrqb2UkRAd/B9zmgaTjxfsYmnO4ejZMON0pqcng3GNcy4CNhkziOvni40RDYvmHAy8B8MzuluTpfYdTiE/u++A0z+xGeqK/CV/KCWx1mVm9mjcG75kPWuZp7rvHN6MzjaJf7BMPMbsCPkbyOb3x1CQruB47CN6ZuAm6VtDJJ2PsJZrYRp9+3AT0KhAZfxdm5CZJmR8oPxSfFGkkDI+W98JxuDzCo3G2IrCcgzGwkvj0APgfOBl7EmTOA7ZImhrrtgJeArZIGx+RchG8678UZt6TnsyX6aYyZ3Y4n9HWS8m7YZ5mrZnYHvgh2zREUlRpHlICYhH/Hcz7wbXwFmAF8HzgZuB0YTwse/zjA0YAb0+R8hhSwEDe4C4HZkfLewCHEPkeRtD6EeQ8DD5nZoGIreoVRA1wUK+sW/sAnXW67oD/wSZwIiaNruHbAN1KT8AQ+UXPojW8T5CUAAlLNVTM7Cl8oHowxfRUZR7uW/jiwrcDMlgODgCPjiXBC3Sk4u9Q3KRw8WGFm1+FMZjdJm4vVLyLraHwzfJqkqyrQvSQd38U31k+TtDxSXpFxtMbHgQc9QljQB9hQzJACZuDnCNNueB+oqAX+ltWQAk7D6e7pFZDVBIEFnAIsjBpSQEXGUfVMKRBOgm/AD0OOKVY/tBkCnAH8qsDRoiqaCWZ2Ih4W1qtCH6HG0dIfB7YVnByuJYdskpbhzF4VrQBJz+OnxZsNVc9URRUVQjVnqqKKCuF/hg/HHMc7Mw0AAAAASUVORK5CYII=", + "text/latex": [ + "$\\displaystyle \\frac{\\alpha^{- k - \\frac{1}{2}} \\gamma\\left(k + \\frac{1}{2}, \\alpha\\right)}{2}$" + ], + "text/plain": [ + " -k - 1/2 \n", + "α ⋅γ(k + 1/2, α)\n", + "───────────────────────\n", + " 2 " + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "ZZ01 = Integral(x ** (2 * k) * exp(-a * x**2), (x, 0, 1))\n", + "ZZ01_expanded = simplify(ZZ01.doit())\n", + "display(ZZ01, ZZ01_expanded)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "hgei", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} From 6961f312c492fa4ee2773a9a7a6d04000a639925 Mon Sep 17 00:00:00 2001 From: Hatem Helal Date: Mon, 18 Sep 2023 13:41:53 +0000 Subject: [PATCH 10/13] updating unit test install --- .github/workflows/unittest.yaml | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/.github/workflows/unittest.yaml b/.github/workflows/unittest.yaml index 67afe73..e10c219 100644 --- a/.github/workflows/unittest.yaml +++ b/.github/workflows/unittest.yaml @@ -21,9 +21,7 @@ jobs: - name: Install requirements run: | pip install -U pip - pip install -r requirements_ipu.txt - pip install -r requirements_test.txt - pip install -e . + pip install -e ".[test,ipu]" - name: Log installed environment run: | From 61ab7788ab643a0601fc9708e3c1e01baf9a265c Mon Sep 17 00:00:00 2001 From: Hatem Helal Date: Mon, 18 Sep 2023 13:42:45 +0000 Subject: [PATCH 11/13] add badge with unittest status --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index e8ae7a4..6455503 100644 --- a/README.md +++ b/README.md @@ -2,6 +2,7 @@ [![notebook-tests](https://github.com/graphcore-research/pyscf-ipu/actions/workflows/notebooks.yaml/badge.svg)](https://github.com/graphcore-research/pyscf-ipu/actions/workflows/notebooks.yaml) [![nanoDFT CLI](https://github.com/graphcore-research/pyscf-ipu/actions/workflows/cli.yaml/badge.svg)](https://github.com/graphcore-research/pyscf-ipu/actions/workflows/cli.yaml) +[![unit tests](https://github.com/graphcore-research/pyscf-ipu/actions/workflows/unittest.yaml/badge.svg)](https://github.com/graphcore-research/pyscf-ipu/actions/workflows/unittest.yaml) [**Installation guide**](#installation) | [**Example DFT Computations**](#example-dft-computations) From 3afef1d3b9a6134f3fe5768f4daf504623fea178 Mon Sep 17 00:00:00 2001 From: Hatem Helal Date: Mon, 18 Sep 2023 14:45:35 +0100 Subject: [PATCH 12/13] inline lmax conditional suggested by @awf Co-authored-by: Andrew Fitzgibbon --- pyscf_ipu/experimental/integrals.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyscf_ipu/experimental/integrals.py b/pyscf_ipu/experimental/integrals.py index 8aa895a..10c9a7b 100644 --- a/pyscf_ipu/experimental/integrals.py +++ b/pyscf_ipu/experimental/integrals.py @@ -298,7 +298,8 @@ def lmax(i: int, j: int, k: int): for idx in range(n): for jdx in range(idx + 1): for kdx in range(idx + 1): - for ldx in range(lmax(idx, jdx, kdx) + 1): + lmax = jdx if idx == kdx else kdx + for ldx in range(lmax + 1): yield idx, jdx, kdx, ldx From 9de35438b8f893cb9c443718d13ae0f1b85a2fb9 Mon Sep 17 00:00:00 2001 From: Hatem Helal Date: Mon, 18 Sep 2023 13:49:44 +0000 Subject: [PATCH 13/13] adding comment on LMAX --- pyscf_ipu/experimental/integrals.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/pyscf_ipu/experimental/integrals.py b/pyscf_ipu/experimental/integrals.py index 10c9a7b..e3e0613 100644 --- a/pyscf_ipu/experimental/integrals.py +++ b/pyscf_ipu/experimental/integrals.py @@ -15,6 +15,8 @@ from .primitive import Primitive, product from .types import IntN, FloatN, FloatNxN, Float3, FloatNx3 +# Maximum value an individual component of the angular momentum lmn can take +# Used for static ahead-of-time compilation of functions involving lmn. LMAX = 4 """ @@ -291,10 +293,6 @@ def gen_ijkl(n: int): """ adapted from four-index transformations by S Wilson pg 257 """ - - def lmax(i: int, j: int, k: int): - return j if i == k else k - for idx in range(n): for jdx in range(idx + 1): for kdx in range(idx + 1):