Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add function decorator for comparing fp64 to fp32 #110

Merged
merged 3 commits into from
Sep 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion pyscf_ipu/experimental/integrals.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,25 @@
"""


def factorial(n: IntN, nmax: int = LMAX) -> IntN:
def factorial_fori(n: IntN, nmax: int = LMAX) -> IntN:
def body_fun(i, val):
return val * jnp.where(i <= n, i, 1)

return lax.fori_loop(1, nmax + 1, body_fun, jnp.ones_like(n))


def factorial_gamma(n: IntN) -> IntN:
"""Appoximate factorial by evaluating the gamma function in log-space.

This approximation is exact for small integers (n < 10).
"""
approx = jnp.exp(gammaln(n + 1))
return jnp.rint(approx)


factorial = factorial_fori


def factorial2(n: IntN, nmax: int = 2 * LMAX) -> IntN:
def body_fun(i, val):
return val * jnp.where((i <= n) & (n % 2 == i % 2), i, 1)
Expand Down
47 changes: 47 additions & 0 deletions pyscf_ipu/experimental/numerics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
from functools import wraps
from typing import Callable

import jax.numpy as jnp
import numpy as np
from jax.experimental import enable_x64
from jaxtyping import Array


def apply_fpcast(v: Array, dtype: np.dtype):
if isinstance(v, jnp.ndarray) and np.issubdtype(v, np.floating):
return v.astype(dtype)

return v


def fpcast(func: Callable, dtype=jnp.float32):
@wraps(func)
def wrapper(*args, **kwargs):
inputs = [apply_fpcast(v, dtype) for v in args]
outputs = func(*inputs, **kwargs)
return outputs

return wrapper


def compare_fp32_to_fp64(func: Callable):
@wraps(func)
def wrapper(*args, **kwargs):
with enable_x64():
outputs_fp32 = fpcast(func, dtype=jnp.float32)(*args, **kwargs)
outputs_fp64 = fpcast(func, dtype=jnp.float64)(*args, **kwargs)
print_compare(func.__name__, outputs_fp32, outputs_fp64)
return outputs_fp32

return wrapper


def print_compare(name: str, fp32, fp64):
fp32 = [fp32] if isinstance(fp32, jnp.ndarray) else fp32
fp64 = [fp64] if isinstance(fp64, jnp.ndarray) else fp64

for idx, (low, high) in enumerate(zip(fp32, fp64)):
low = np.asarray(low).astype(np.float64)
high = np.asarray(high)
print(f"{name} output {idx} has max |fp64 - fp32| = {np.abs(high - low).max()}")
14 changes: 14 additions & 0 deletions test/test_numerics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
import jax.numpy as jnp
from numpy.testing import assert_allclose

from pyscf_ipu.experimental.integrals import factorial_fori, factorial_gamma
from pyscf_ipu.experimental.numerics import compare_fp32_to_fp64


def test_factorial():
n = 16
x = jnp.arange(n, dtype=jnp.float32)
y_fori = compare_fp32_to_fp64(factorial_fori)(x, n)
y_gamma = compare_fp32_to_fp64(factorial_gamma)(x)
assert_allclose(y_fori, y_gamma, 1e-2)