Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactoring special function implementations #112

Merged
merged 4 commits into from
Sep 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 26 additions & 79 deletions pyscf_ipu/experimental/integrals.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,101 +6,47 @@

import jax.numpy as jnp
import numpy as np
from jax import jit, lax, tree_map, vmap
from jax import jit, tree_map, vmap
from jax.ops import segment_sum
from jax.scipy.special import gammainc, gammaln

from .basis import Basis
from .orbital import batch_orbitals
from .primitive import Primitive, product
from .types import Float3, FloatN, FloatNx3, FloatNxN, IntN
from .special import binom, binom_factor, factorial, factorial2, gammanu
from .types import Float3, FloatN, FloatNx3, FloatNxN
from .units import LMAX

# Maximum value an individual component of the angular momentum lmn can take
# Used for static ahead-of-time compilation of functions involving lmn.
LMAX = 4

"""
Special functions used in integral evaluation
"""
JAX implementation for integrals over Gaussian basis functions. Based upon the
closed-form expressions derived in

Taketa, H., Huzinaga, S., & O-ohata, K. (1966). Gaussian-expansion methods for
molecular integrals. Journal of the physical society of Japan, 21(11), 2313-2324.
<https://doi.org/10.1143/JPSJ.21.2313>

def factorial_fori(n: IntN, nmax: int = LMAX) -> IntN:
def body_fun(i, val):
return val * jnp.where(i <= n, i, 1)

return lax.fori_loop(1, nmax + 1, body_fun, jnp.ones_like(n))


def factorial_gamma(n: IntN) -> IntN:
"""Appoximate factorial by evaluating the gamma function in log-space.

This approximation is exact for small integers (n < 10).
"""
approx = jnp.exp(gammaln(n + 1))
return jnp.rint(approx)


factorial = factorial_fori


def factorial2(n: IntN, nmax: int = 2 * LMAX) -> IntN:
def body_fun(i, val):
return val * jnp.where((i <= n) & (n % 2 == i % 2), i, 1)

return lax.fori_loop(1, nmax + 1, body_fun, jnp.ones_like(n))

Hereafter referred to as the "THO paper"

def binom(x: IntN, y: IntN, nmax: int = LMAX) -> IntN:
bang = partial(factorial, nmax=nmax)
c = x * bang(x - 1) / (bang(y) * bang(x - y))
return jnp.where(x == y, 1, c)
Related work:


def gammanu(nu: IntN, t: FloatN, epsilon: float = 1e-10) -> FloatN:
"""
eq 2.11 from THO but simplified using SymPy and converted to jax

t, u = symbols("t u", real=True, positive=True)
nu = Symbol("nu", integer=True, nonnegative=True)

expr = simplify(integrate(u ** (2 * nu) * exp(-t * u**2), (u, 0, 1)))
f = lambdify((nu, t), expr, modules="scipy")
?f

We evaulate this in log-space to avoid overflow/nan
"""
t = jnp.maximum(t, epsilon)
x = nu + 0.5
gn = jnp.log(0.5) - x * jnp.log(t) + jnp.log(gammainc(x, t)) + gammaln(x)
return jnp.exp(gn)


@partial(vmap, in_axes=(0, None, None, None, None))
def binom_factor(s: IntN, i: int, j: int, pa: float, pb: float):
"""
Eq. 15 from Augspurger JD, Dykstra CE. General quantum mechanical operators. An
[1] Augspurger JD, Dykstra CE. General quantum mechanical operators. An
open-ended approach for one-electron integrals with Gaussian bases. Journal of
computational chemistry. 1990 Jan;11(1):105-11.
<https://doi.org/10.1002/jcc.540110113>
"""

def term(t):
return binom(i, s - t) * binom(j, t) * pa ** (i - s + t) * pb ** (j - t)

def body_fun(t, val):
mask = (t <= s) & (t >= (s - i)) & (t <= j)
return val + jnp.where(mask, term(t), 0.0)

return lax.fori_loop(0, LMAX + 1, body_fun, 0.0)
[2] PyQuante: <https://github.com/rpmuller/pyquante2/>
"""


@partial(vmap, in_axes=(0, 0, 0, 0, None))
def overlap_axis(i: int, j: int, pa: int, pb: int, alpha: float) -> float:
ii = jnp.arange(LMAX + 1)
out = binom_factor(2 * ii, i, j, pa, pb)
out *= factorial2(2 * ii - 1) / (2 * alpha) ** ii
mask = ii <= jnp.floor_divide(i + j, 2)
out = jnp.where(mask, out, 0.0)
def overlap_axis(i: int, j: int, a: float, b: float, alpha: float) -> float:
idx = [(s, t) for s in range(LMAX + 1) for t in range(2 * s + 1)]
s, t = jnp.array(idx, dtype=jnp.uint32).T
out = binom(i, 2 * s - t) * binom(j, t)
out *= a ** (i - (2 * s - t)) * b ** (j - t)
out *= factorial2(2 * s - 1) / (2 * alpha) ** s

mask = (2 * s - i <= t) & (t <= j)
out = jnp.where(mask, out, 0)
return jnp.sum(out)


Expand Down Expand Up @@ -179,7 +125,7 @@ def g_term(l1, l2, pa, pb, cp):
index = i - 2 * r - u
g = (
jnp.power(-1, i + u)
* binom_factor(i, l1, l2, pa, pb)
* jnp.take(binom_factor(l1, l2, pa, pb), i)
* factorial(i)
* jnp.power(cp, index - u)
* jnp.power(epsilon, r + u)
Expand Down Expand Up @@ -251,7 +197,7 @@ def H(l1, l2, a, b, i, r, gamma):
# Note this should match THO Eq 3.5 but that seems to incorrectly show a
# 1/(4 gamma) ^(i- 2r) term which is inconsistent with Eq 2.22.
# Using (4 gamma)^(r - i) matches the reported expressions for H_L
u = factorial(i) * binom_factor(i, l1, l2, a, b)
u = factorial(i) * jnp.take(binom_factor(l1, l2, a, b, 2 * LMAX), i)
v = factorial(r) * factorial(i - 2 * r) * (4 * gamma) ** (i - r)
return u / v

Expand All @@ -269,6 +215,7 @@ def c_term(la, lb, lc, ld, pa, pb, qc, qd, qp):
return segment_sum(c, index, num_segments=4 * LMAX + 1)

# Manual vmap over cartesian axes (x, y, z) as ran into possible bug.
# See https://github.com/graphcore-research/pyscf-ipu/issues/105
args = [a.lmn, b.lmn, c.lmn, d.lmn, pa, pb, qc, qd, qp]
Ci, Cj, Ck = [c_term(*[v.at[i].get() for v in args]) for i in range(3)]

Expand Down
110 changes: 110 additions & 0 deletions pyscf_ipu/experimental/special.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
from functools import partial

import jax.numpy as jnp
import numpy as np
from jax import lax
from jax.ops import segment_sum
from jax.scipy.special import betaln, gammainc, gammaln

from .types import FloatN, IntN
from .units import LMAX


def factorial_fori(n: IntN, nmax: int = LMAX) -> IntN:
def body_fun(i, val):
return val * jnp.where(i <= n, i, 1)

return lax.fori_loop(1, nmax + 1, body_fun, jnp.ones_like(n))


def factorial_gamma(n: IntN) -> IntN:
"""Appoximate factorial by evaluating the gamma function in log-space.

This approximation is exact for small integers (n < 10).
"""
approx = jnp.exp(gammaln(n + 1))
return jnp.rint(approx)


def factorial_lookup(n: IntN, nmax: int = LMAX) -> IntN:
N = np.cumprod(np.arange(1, nmax + 1))
N = np.insert(N, 0, 1)
N = jnp.array(N, dtype=jnp.uint32)
return N.at[n.astype(jnp.uint32)].get()


factorial = factorial_gamma


def factorial2_fori(n: IntN, nmax: int = 2 * LMAX) -> IntN:
def body_fun(i, val):
return val * jnp.where((i <= n) & (n % 2 == i % 2), i, 1)

return lax.fori_loop(1, nmax + 1, body_fun, jnp.ones_like(n))


def factorial2_lookup(n: IntN, nmax: int = 2 * LMAX) -> IntN:
stop = nmax + 1 if nmax % 2 == 0 else nmax + 2
N = np.arange(1, stop).reshape(-1, 2)
N = np.cumprod(N, axis=0).reshape(-1)
N = np.insert(N, 0, 1)
N = jnp.array(N)
n = jnp.maximum(n, 0)
return N.at[n].get()


factorial2 = factorial2_lookup


def binom_beta(x: IntN, y: IntN) -> IntN:
approx = 1.0 / ((x + 1) * jnp.exp(betaln(x - y + 1, y + 1)))
return jnp.rint(approx)


def binom_fori(x: IntN, y: IntN, nmax: int = LMAX) -> IntN:
bang = partial(factorial_fori, nmax=nmax)
c = x * bang(x - 1) / (bang(y) * bang(x - y))
return jnp.where(x == y, 1, c)


def binom_lookup(x: IntN, y: IntN, nmax: int = LMAX) -> IntN:
bang = partial(factorial_lookup, nmax=nmax)
c = x * bang(x - 1) / (bang(y) * bang(x - y))
return jnp.where(x == y, 1, c)


binom = binom_lookup


def gammanu(nu: IntN, t: FloatN, epsilon: float = 1e-10) -> FloatN:
"""
eq 2.11 from THO but simplified using SymPy and converted to jax

t, u = symbols("t u", real=True, positive=True)
nu = Symbol("nu", integer=True, nonnegative=True)

expr = simplify(integrate(u ** (2 * nu) * exp(-t * u**2), (u, 0, 1)))
f = lambdify((nu, t), expr, modules="scipy")
?f

We evaulate this in log-space to avoid overflow/nan
"""
t = jnp.maximum(t, epsilon)
x = nu + 0.5
gn = jnp.log(0.5) - x * jnp.log(t) + jnp.log(gammainc(x, t)) + gammaln(x)
return jnp.exp(gn)


def binom_factor(i: int, j: int, a: float, b: float, lmax: int = LMAX) -> FloatN:
"""
Eq. 15 from Augspurger JD, Dykstra CE. General quantum mechanical operators. An
open-ended approach for one-electron integrals with Gaussian bases. Journal of
computational chemistry. 1990 Jan;11(1):105-11.
<https://doi.org/10.1002/jcc.540110113>
"""
s, t = jnp.tril_indices(lmax + 1)
out = binom(i, s - t) * binom(j, t) * a ** (i - (s - t)) * b ** (j - t)
mask = ((s - i) <= t) & (t <= j)
out = jnp.where(mask, out, 0.0)
return segment_sum(out, s, num_segments=lmax + 1)
4 changes: 4 additions & 0 deletions pyscf_ipu/experimental/units.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
from jaxtyping import Array

# Maximum value an individual component of the angular momentum lmn can take
# Used for static ahead-of-time compilation of functions involving lmn.
LMAX = 4

BOHR_PER_ANGSTROM = 0.529177210903


Expand Down
51 changes: 44 additions & 7 deletions test/test_numerics.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,51 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
import jax.numpy as jnp
import pytest
from numpy.testing import assert_allclose

from pyscf_ipu.experimental.integrals import factorial_fori, factorial_gamma
from pyscf_ipu.experimental.numerics import compare_fp32_to_fp64
from pyscf_ipu.experimental.special import (
binom_beta,
binom_fori,
binom_lookup,
factorial2_fori,
factorial2_lookup,
factorial_fori,
factorial_gamma,
factorial_lookup,
)


def test_factorial():
n = 16
x = jnp.arange(n, dtype=jnp.float32)
y_fori = compare_fp32_to_fp64(factorial_fori)(x, n)
y_gamma = compare_fp32_to_fp64(factorial_gamma)(x)
assert_allclose(y_fori, y_gamma, 1e-2)
x = jnp.array([1, 2, 3, 4, 5, 6, 7, 8])
expect = jnp.array([1, 2, 6, 24, 120, 720, 5040, 40320])
assert_allclose(factorial_fori(x, x[-1]), expect)
assert_allclose(factorial_lookup(x, x[-1]), expect)
assert_allclose(factorial_gamma(x), expect)


def test_factorial2():
x = jnp.array([1, 2, 3, 4, 5, 6, 7, 8])
expect = jnp.array([1, 2, 3, 8, 15, 48, 105, 384])
assert_allclose(factorial2_fori(x), expect)
assert_allclose(factorial2_fori(0), 1)

assert_allclose(factorial2_lookup(x), expect)
assert_allclose(factorial2_lookup(0), 1)


@pytest.mark.parametrize("binom_func", [binom_beta, binom_fori, binom_lookup])
def test_binom(binom_func):
x = jnp.array([4, 4, 4, 4])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe test some edge cases too (1,1) (3,3) (3,0) (0,-1) etc

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good idea, I think $3 \choose 3$ and $3 \choose 0$ should already be covered but I've added some more cases involving ones, zeros, and minus ones.

y = jnp.array([1, 2, 3, 4])
expect = jnp.array([4, 6, 4, 1])
assert_allclose(binom_func(x, y), expect)

zero = jnp.array([0])
assert_allclose(binom_func(zero, y), jnp.zeros_like(x))
assert_allclose(binom_func(x, zero), jnp.ones_like(y))
assert_allclose(binom_func(y, y), jnp.ones_like(y))

one = jnp.array([1])
assert_allclose(binom_func(one, one), one)
assert_allclose(binom_func(zero, -one), zero)
assert_allclose(binom_func(zero, zero), one)