Skip to content

Commit

Permalink
adding gammanu_series and gammanu_gamma as two impls for gammanu
Browse files Browse the repository at this point in the history
  • Loading branch information
hatemhelal committed Oct 9, 2023
1 parent 8ec5bc7 commit 70c9fa0
Showing 1 changed file with 23 additions and 2 deletions.
25 changes: 23 additions & 2 deletions pyscf_ipu/experimental/special.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np
from jax import lax
from jax.ops import segment_sum
from jax.scipy.special import betaln, gammaln
from jax.scipy.special import betaln, gammainc, gammaln

from .types import FloatN, IntN
from .units import LMAX
Expand Down Expand Up @@ -77,7 +77,25 @@ def binom_lookup(x: IntN, y: IntN, nmax: int = LMAX) -> IntN:
binom = binom_lookup


def gammanu(nu: IntN, t: FloatN, num_terms: int = 128) -> FloatN:
def gammanu_gamma(nu: IntN, t: FloatN, epsilon: float = 1e-10) -> FloatN:
"""
eq 2.11 from THO but simplified using SymPy and converted to jax
t, u = symbols("t u", real=True, positive=True)
nu = Symbol("nu", integer=True, nonnegative=True)
expr = simplify(integrate(u ** (2 * nu) * exp(-t * u**2), (u, 0, 1)))
f = lambdify((nu, t), expr, modules="scipy")
?f
We evaulate this in log-space to avoid overflow/nan
"""
x = nu + 0.5
gn = jnp.log(0.5) - x * jnp.log(t) + jnp.log(gammainc(x, t)) + gammaln(x)
return jnp.where(t <= epsilon, 1 / (2 * nu + 1), jnp.exp(gn))


def gammanu_series(nu: IntN, t: FloatN, num_terms: int = 128) -> FloatN:
"""
eq 2.11 from THO but simplified as derived in equation 19 of gammanu.ipynb
"""
Expand All @@ -93,6 +111,9 @@ def gammanu(nu: IntN, t: FloatN, num_terms: int = 128) -> FloatN:
return jnp.exp(-t) / 2 * total


gammanu = gammanu_series


def binom_factor(i: int, j: int, a: float, b: float, lmax: int = LMAX) -> FloatN:
"""
Eq. 15 from Augspurger JD, Dykstra CE. General quantum mechanical operators. An
Expand Down

0 comments on commit 70c9fa0

Please sign in to comment.