Skip to content

Commit

Permalink
Add function decorator for comparing fp64 to fp32 (#110)
Browse files Browse the repository at this point in the history
* wip: numeric comparison tooling

* review feedback

* move enable_x64 into decorator
  • Loading branch information
hatemhelal authored Sep 27, 2023
1 parent 6bbbf7e commit ec94c8b
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 1 deletion.
14 changes: 13 additions & 1 deletion pyscf_ipu/experimental/integrals.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,25 @@
"""


def factorial(n: IntN, nmax: int = LMAX) -> IntN:
def factorial_fori(n: IntN, nmax: int = LMAX) -> IntN:
def body_fun(i, val):
return val * jnp.where(i <= n, i, 1)

return lax.fori_loop(1, nmax + 1, body_fun, jnp.ones_like(n))


def factorial_gamma(n: IntN) -> IntN:
"""Appoximate factorial by evaluating the gamma function in log-space.
This approximation is exact for small integers (n < 10).
"""
approx = jnp.exp(gammaln(n + 1))
return jnp.rint(approx)


factorial = factorial_fori


def factorial2(n: IntN, nmax: int = 2 * LMAX) -> IntN:
def body_fun(i, val):
return val * jnp.where((i <= n) & (n % 2 == i % 2), i, 1)
Expand Down
47 changes: 47 additions & 0 deletions pyscf_ipu/experimental/numerics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
from functools import wraps
from typing import Callable

import jax.numpy as jnp
import numpy as np
from jax.experimental import enable_x64
from jaxtyping import Array


def apply_fpcast(v: Array, dtype: np.dtype):
if isinstance(v, jnp.ndarray) and np.issubdtype(v, np.floating):
return v.astype(dtype)

return v


def fpcast(func: Callable, dtype=jnp.float32):
@wraps(func)
def wrapper(*args, **kwargs):
inputs = [apply_fpcast(v, dtype) for v in args]
outputs = func(*inputs, **kwargs)
return outputs

return wrapper


def compare_fp32_to_fp64(func: Callable):
@wraps(func)
def wrapper(*args, **kwargs):
with enable_x64():
outputs_fp32 = fpcast(func, dtype=jnp.float32)(*args, **kwargs)
outputs_fp64 = fpcast(func, dtype=jnp.float64)(*args, **kwargs)
print_compare(func.__name__, outputs_fp32, outputs_fp64)
return outputs_fp32

return wrapper


def print_compare(name: str, fp32, fp64):
fp32 = [fp32] if isinstance(fp32, jnp.ndarray) else fp32
fp64 = [fp64] if isinstance(fp64, jnp.ndarray) else fp64

for idx, (low, high) in enumerate(zip(fp32, fp64)):
low = np.asarray(low).astype(np.float64)
high = np.asarray(high)
print(f"{name} output {idx} has max |fp64 - fp32| = {np.abs(high - low).max()}")
14 changes: 14 additions & 0 deletions test/test_numerics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
import jax.numpy as jnp
from numpy.testing import assert_allclose

from pyscf_ipu.experimental.integrals import factorial_fori, factorial_gamma
from pyscf_ipu.experimental.numerics import compare_fp32_to_fp64


def test_factorial():
n = 16
x = jnp.arange(n, dtype=jnp.float32)
y_fori = compare_fp32_to_fp64(factorial_fori)(x, n)
y_gamma = compare_fp32_to_fp64(factorial_gamma)(x)
assert_allclose(y_fori, y_gamma, 1e-2)

0 comments on commit ec94c8b

Please sign in to comment.