diff --git a/pyscf_ipu/experimental/integrals.py b/pyscf_ipu/experimental/integrals.py index b9ebd34..c6f0dfa 100644 --- a/pyscf_ipu/experimental/integrals.py +++ b/pyscf_ipu/experimental/integrals.py @@ -24,13 +24,25 @@ """ -def factorial(n: IntN, nmax: int = LMAX) -> IntN: +def factorial_fori(n: IntN, nmax: int = LMAX) -> IntN: def body_fun(i, val): return val * jnp.where(i <= n, i, 1) return lax.fori_loop(1, nmax + 1, body_fun, jnp.ones_like(n)) +def factorial_gamma(n: IntN) -> IntN: + """Appoximate factorial by evaluating the gamma function in log-space. + + This approximation is exact for small integers (n < 10). + """ + approx = jnp.exp(gammaln(n + 1)) + return jnp.rint(approx) + + +factorial = factorial_fori + + def factorial2(n: IntN, nmax: int = 2 * LMAX) -> IntN: def body_fun(i, val): return val * jnp.where((i <= n) & (n % 2 == i % 2), i, 1) diff --git a/pyscf_ipu/experimental/numerics.py b/pyscf_ipu/experimental/numerics.py new file mode 100644 index 0000000..066e159 --- /dev/null +++ b/pyscf_ipu/experimental/numerics.py @@ -0,0 +1,47 @@ +# Copyright (c) 2023 Graphcore Ltd. All rights reserved. +from functools import wraps +from typing import Callable + +import jax.numpy as jnp +import numpy as np +from jax.experimental import enable_x64 +from jaxtyping import Array + + +def apply_fpcast(v: Array, dtype: np.dtype): + if isinstance(v, jnp.ndarray) and np.issubdtype(v, np.floating): + return v.astype(dtype) + + return v + + +def fpcast(func: Callable, dtype=jnp.float32): + @wraps(func) + def wrapper(*args, **kwargs): + inputs = [apply_fpcast(v, dtype) for v in args] + outputs = func(*inputs, **kwargs) + return outputs + + return wrapper + + +def compare_fp32_to_fp64(func: Callable): + @wraps(func) + def wrapper(*args, **kwargs): + with enable_x64(): + outputs_fp32 = fpcast(func, dtype=jnp.float32)(*args, **kwargs) + outputs_fp64 = fpcast(func, dtype=jnp.float64)(*args, **kwargs) + print_compare(func.__name__, outputs_fp32, outputs_fp64) + return outputs_fp32 + + return wrapper + + +def print_compare(name: str, fp32, fp64): + fp32 = [fp32] if isinstance(fp32, jnp.ndarray) else fp32 + fp64 = [fp64] if isinstance(fp64, jnp.ndarray) else fp64 + + for idx, (low, high) in enumerate(zip(fp32, fp64)): + low = np.asarray(low).astype(np.float64) + high = np.asarray(high) + print(f"{name} output {idx} has max |fp64 - fp32| = {np.abs(high - low).max()}") diff --git a/test/test_numerics.py b/test/test_numerics.py new file mode 100644 index 0000000..1f52af4 --- /dev/null +++ b/test/test_numerics.py @@ -0,0 +1,14 @@ +# Copyright (c) 2023 Graphcore Ltd. All rights reserved. +import jax.numpy as jnp +from numpy.testing import assert_allclose + +from pyscf_ipu.experimental.integrals import factorial_fori, factorial_gamma +from pyscf_ipu.experimental.numerics import compare_fp32_to_fp64 + + +def test_factorial(): + n = 16 + x = jnp.arange(n, dtype=jnp.float32) + y_fori = compare_fp32_to_fp64(factorial_fori)(x, n) + y_gamma = compare_fp32_to_fp64(factorial_gamma)(x) + assert_allclose(y_fori, y_gamma, 1e-2)