Skip to content

Commit

Permalink
review feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
hatemhelal committed Sep 27, 2023
1 parent 1e707a5 commit 360cdd4
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 14 deletions.
7 changes: 5 additions & 2 deletions pyscf_ipu/experimental/integrals.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,12 @@ def body_fun(i, val):


def factorial_gamma(n: IntN) -> IntN:
from jax.scipy.special import gammaln
"""Appoximate factorial by evaluating the gamma function in log-space.
return jnp.exp(gammaln(n + 1))
This approximation is exact for small integers (n < 10).
"""
approx = jnp.exp(gammaln(n + 1))
return jnp.rint(approx)


factorial = factorial_fori
Expand Down
18 changes: 10 additions & 8 deletions pyscf_ipu/experimental/numerics.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
from functools import wraps
from typing import Callable

import jax.numpy as jnp
import numpy as np
from jaxtyping import Array


def apply_fpcast(v, dtype):
if isinstance(v, jnp.ndarray) and v.dtype.kind == "f":
def apply_fpcast(v: Array, dtype: np.dtype):
if isinstance(v, jnp.ndarray) and np.issubdtype(v, np.floating):
return v.astype(dtype)

return v


def fpcast(func, dtype=jnp.float32):
def fpcast(func: Callable, dtype=jnp.float32):
@wraps(func)
def wrapper(*args, **kwargs):
inputs = [apply_fpcast(v, dtype) for v in args]
Expand All @@ -22,22 +24,22 @@ def wrapper(*args, **kwargs):
return wrapper


def compare_fp32_to_fp64(func):
def compare_fp32_to_fp64(func: Callable):
@wraps(func)
def wrapper(*args, **kwargs):
outputs_fp32 = fpcast(func, dtype=jnp.float32)(*args, **kwargs)
outputs_fp64 = fpcast(func, dtype=jnp.float64)(*args, **kwargs)
print_compare(outputs_fp32, outputs_fp64)
print_compare(func.__name__, outputs_fp32, outputs_fp64)
return outputs_fp32

return wrapper


def print_compare(fp32, fp64):
def print_compare(name: str, fp32, fp64):
fp32 = [fp32] if isinstance(fp32, jnp.ndarray) else fp32
fp64 = [fp64] if isinstance(fp64, jnp.ndarray) else fp64

for low, high in zip(fp32, fp64):
for idx, (low, high) in enumerate(zip(fp32, fp64)):
low = np.asarray(low).astype(np.float64)
high = np.asarray(high)
print(f" max |fp64 - fp32| = {np.abs(high - low).max()}")
print(f"{name} output {idx} has max |fp64 - fp32| = {np.abs(high - low).max()}")
11 changes: 7 additions & 4 deletions test/test_numerics.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
import jax.numpy as jnp
from jax.experimental import enable_x64
from numpy.testing import assert_allclose

from pyscf_ipu.experimental.integrals import factorial_fori, factorial_gamma
from pyscf_ipu.experimental.numerics import compare_fp32_to_fp64


def test_factorial():
x = jnp.arange(8, dtype=jnp.float32)
y_fori = compare_fp32_to_fp64(factorial_fori)(x, 8)
y_gamma = compare_fp32_to_fp64(factorial_gamma)(x)
assert_allclose(y_fori, y_gamma, atol=1e-2)
with enable_x64():
n = 16
x = jnp.arange(n, dtype=jnp.float32)
y_fori = compare_fp32_to_fp64(factorial_fori)(x, n)
y_gamma = compare_fp32_to_fp64(factorial_gamma)(x)
assert_allclose(y_fori, y_gamma, 1e-2)

0 comments on commit 360cdd4

Please sign in to comment.