-
Notifications
You must be signed in to change notification settings - Fork 2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactoring special function implementations #112
Merged
Merged
Changes from all commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
# Copyright (c) 2023 Graphcore Ltd. All rights reserved. | ||
from functools import partial | ||
|
||
import jax.numpy as jnp | ||
import numpy as np | ||
from jax import lax | ||
from jax.ops import segment_sum | ||
from jax.scipy.special import betaln, gammainc, gammaln | ||
|
||
from .types import FloatN, IntN | ||
from .units import LMAX | ||
|
||
|
||
def factorial_fori(n: IntN, nmax: int = LMAX) -> IntN: | ||
def body_fun(i, val): | ||
return val * jnp.where(i <= n, i, 1) | ||
|
||
return lax.fori_loop(1, nmax + 1, body_fun, jnp.ones_like(n)) | ||
|
||
|
||
def factorial_gamma(n: IntN) -> IntN: | ||
"""Appoximate factorial by evaluating the gamma function in log-space. | ||
|
||
This approximation is exact for small integers (n < 10). | ||
""" | ||
approx = jnp.exp(gammaln(n + 1)) | ||
return jnp.rint(approx) | ||
|
||
|
||
def factorial_lookup(n: IntN, nmax: int = LMAX) -> IntN: | ||
N = np.cumprod(np.arange(1, nmax + 1)) | ||
N = np.insert(N, 0, 1) | ||
N = jnp.array(N, dtype=jnp.uint32) | ||
return N.at[n.astype(jnp.uint32)].get() | ||
|
||
|
||
factorial = factorial_gamma | ||
|
||
|
||
def factorial2_fori(n: IntN, nmax: int = 2 * LMAX) -> IntN: | ||
def body_fun(i, val): | ||
return val * jnp.where((i <= n) & (n % 2 == i % 2), i, 1) | ||
|
||
return lax.fori_loop(1, nmax + 1, body_fun, jnp.ones_like(n)) | ||
|
||
|
||
def factorial2_lookup(n: IntN, nmax: int = 2 * LMAX) -> IntN: | ||
stop = nmax + 1 if nmax % 2 == 0 else nmax + 2 | ||
N = np.arange(1, stop).reshape(-1, 2) | ||
N = np.cumprod(N, axis=0).reshape(-1) | ||
N = np.insert(N, 0, 1) | ||
N = jnp.array(N) | ||
n = jnp.maximum(n, 0) | ||
return N.at[n].get() | ||
|
||
|
||
factorial2 = factorial2_lookup | ||
|
||
|
||
def binom_beta(x: IntN, y: IntN) -> IntN: | ||
approx = 1.0 / ((x + 1) * jnp.exp(betaln(x - y + 1, y + 1))) | ||
return jnp.rint(approx) | ||
|
||
|
||
def binom_fori(x: IntN, y: IntN, nmax: int = LMAX) -> IntN: | ||
bang = partial(factorial_fori, nmax=nmax) | ||
c = x * bang(x - 1) / (bang(y) * bang(x - y)) | ||
return jnp.where(x == y, 1, c) | ||
|
||
|
||
def binom_lookup(x: IntN, y: IntN, nmax: int = LMAX) -> IntN: | ||
bang = partial(factorial_lookup, nmax=nmax) | ||
c = x * bang(x - 1) / (bang(y) * bang(x - y)) | ||
return jnp.where(x == y, 1, c) | ||
|
||
|
||
binom = binom_lookup | ||
|
||
|
||
def gammanu(nu: IntN, t: FloatN, epsilon: float = 1e-10) -> FloatN: | ||
""" | ||
eq 2.11 from THO but simplified using SymPy and converted to jax | ||
|
||
t, u = symbols("t u", real=True, positive=True) | ||
nu = Symbol("nu", integer=True, nonnegative=True) | ||
|
||
expr = simplify(integrate(u ** (2 * nu) * exp(-t * u**2), (u, 0, 1))) | ||
f = lambdify((nu, t), expr, modules="scipy") | ||
?f | ||
|
||
We evaulate this in log-space to avoid overflow/nan | ||
""" | ||
t = jnp.maximum(t, epsilon) | ||
x = nu + 0.5 | ||
gn = jnp.log(0.5) - x * jnp.log(t) + jnp.log(gammainc(x, t)) + gammaln(x) | ||
return jnp.exp(gn) | ||
|
||
|
||
def binom_factor(i: int, j: int, a: float, b: float, lmax: int = LMAX) -> FloatN: | ||
""" | ||
Eq. 15 from Augspurger JD, Dykstra CE. General quantum mechanical operators. An | ||
open-ended approach for one-electron integrals with Gaussian bases. Journal of | ||
computational chemistry. 1990 Jan;11(1):105-11. | ||
<https://doi.org/10.1002/jcc.540110113> | ||
""" | ||
s, t = jnp.tril_indices(lmax + 1) | ||
out = binom(i, s - t) * binom(j, t) * a ** (i - (s - t)) * b ** (j - t) | ||
mask = ((s - i) <= t) & (t <= j) | ||
out = jnp.where(mask, out, 0.0) | ||
return segment_sum(out, s, num_segments=lmax + 1) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,14 +1,51 @@ | ||
# Copyright (c) 2023 Graphcore Ltd. All rights reserved. | ||
import jax.numpy as jnp | ||
import pytest | ||
from numpy.testing import assert_allclose | ||
|
||
from pyscf_ipu.experimental.integrals import factorial_fori, factorial_gamma | ||
from pyscf_ipu.experimental.numerics import compare_fp32_to_fp64 | ||
from pyscf_ipu.experimental.special import ( | ||
binom_beta, | ||
binom_fori, | ||
binom_lookup, | ||
factorial2_fori, | ||
factorial2_lookup, | ||
factorial_fori, | ||
factorial_gamma, | ||
factorial_lookup, | ||
) | ||
|
||
|
||
def test_factorial(): | ||
n = 16 | ||
x = jnp.arange(n, dtype=jnp.float32) | ||
y_fori = compare_fp32_to_fp64(factorial_fori)(x, n) | ||
y_gamma = compare_fp32_to_fp64(factorial_gamma)(x) | ||
assert_allclose(y_fori, y_gamma, 1e-2) | ||
x = jnp.array([1, 2, 3, 4, 5, 6, 7, 8]) | ||
expect = jnp.array([1, 2, 6, 24, 120, 720, 5040, 40320]) | ||
assert_allclose(factorial_fori(x, x[-1]), expect) | ||
assert_allclose(factorial_lookup(x, x[-1]), expect) | ||
assert_allclose(factorial_gamma(x), expect) | ||
|
||
|
||
def test_factorial2(): | ||
x = jnp.array([1, 2, 3, 4, 5, 6, 7, 8]) | ||
expect = jnp.array([1, 2, 3, 8, 15, 48, 105, 384]) | ||
assert_allclose(factorial2_fori(x), expect) | ||
assert_allclose(factorial2_fori(0), 1) | ||
|
||
assert_allclose(factorial2_lookup(x), expect) | ||
assert_allclose(factorial2_lookup(0), 1) | ||
|
||
|
||
@pytest.mark.parametrize("binom_func", [binom_beta, binom_fori, binom_lookup]) | ||
def test_binom(binom_func): | ||
x = jnp.array([4, 4, 4, 4]) | ||
y = jnp.array([1, 2, 3, 4]) | ||
expect = jnp.array([4, 6, 4, 1]) | ||
assert_allclose(binom_func(x, y), expect) | ||
|
||
zero = jnp.array([0]) | ||
assert_allclose(binom_func(zero, y), jnp.zeros_like(x)) | ||
assert_allclose(binom_func(x, zero), jnp.ones_like(y)) | ||
assert_allclose(binom_func(y, y), jnp.ones_like(y)) | ||
|
||
one = jnp.array([1]) | ||
assert_allclose(binom_func(one, one), one) | ||
assert_allclose(binom_func(zero, -one), zero) | ||
assert_allclose(binom_func(zero, zero), one) |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe test some edge cases too (1,1) (3,3) (3,0) (0,-1) etc
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
good idea, I think$3 \choose 3$ and $3 \choose 0$ should already be covered but I've added some more cases involving ones, zeros, and minus ones.