Skip to content

Commit 1e707a5

Browse files
committed
wip: numeric comparison tooling
1 parent 6bbbf7e commit 1e707a5

File tree

3 files changed

+66
-1
lines changed

3 files changed

+66
-1
lines changed

pyscf_ipu/experimental/integrals.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,22 @@
2424
"""
2525

2626

27-
def factorial(n: IntN, nmax: int = LMAX) -> IntN:
27+
def factorial_fori(n: IntN, nmax: int = LMAX) -> IntN:
2828
def body_fun(i, val):
2929
return val * jnp.where(i <= n, i, 1)
3030

3131
return lax.fori_loop(1, nmax + 1, body_fun, jnp.ones_like(n))
3232

3333

34+
def factorial_gamma(n: IntN) -> IntN:
35+
from jax.scipy.special import gammaln
36+
37+
return jnp.exp(gammaln(n + 1))
38+
39+
40+
factorial = factorial_fori
41+
42+
3443
def factorial2(n: IntN, nmax: int = 2 * LMAX) -> IntN:
3544
def body_fun(i, val):
3645
return val * jnp.where((i <= n) & (n % 2 == i % 2), i, 1)

pyscf_ipu/experimental/numerics.py

+43
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
2+
from functools import wraps
3+
4+
import jax.numpy as jnp
5+
import numpy as np
6+
7+
8+
def apply_fpcast(v, dtype):
9+
if isinstance(v, jnp.ndarray) and v.dtype.kind == "f":
10+
return v.astype(dtype)
11+
12+
return v
13+
14+
15+
def fpcast(func, dtype=jnp.float32):
16+
@wraps(func)
17+
def wrapper(*args, **kwargs):
18+
inputs = [apply_fpcast(v, dtype) for v in args]
19+
outputs = func(*inputs, **kwargs)
20+
return outputs
21+
22+
return wrapper
23+
24+
25+
def compare_fp32_to_fp64(func):
26+
@wraps(func)
27+
def wrapper(*args, **kwargs):
28+
outputs_fp32 = fpcast(func, dtype=jnp.float32)(*args, **kwargs)
29+
outputs_fp64 = fpcast(func, dtype=jnp.float64)(*args, **kwargs)
30+
print_compare(outputs_fp32, outputs_fp64)
31+
return outputs_fp32
32+
33+
return wrapper
34+
35+
36+
def print_compare(fp32, fp64):
37+
fp32 = [fp32] if isinstance(fp32, jnp.ndarray) else fp32
38+
fp64 = [fp64] if isinstance(fp64, jnp.ndarray) else fp64
39+
40+
for low, high in zip(fp32, fp64):
41+
low = np.asarray(low).astype(np.float64)
42+
high = np.asarray(high)
43+
print(f" max |fp64 - fp32| = {np.abs(high - low).max()}")

test/test_numerics.py

+13
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
2+
import jax.numpy as jnp
3+
from numpy.testing import assert_allclose
4+
5+
from pyscf_ipu.experimental.integrals import factorial_fori, factorial_gamma
6+
from pyscf_ipu.experimental.numerics import compare_fp32_to_fp64
7+
8+
9+
def test_factorial():
10+
x = jnp.arange(8, dtype=jnp.float32)
11+
y_fori = compare_fp32_to_fp64(factorial_fori)(x, 8)
12+
y_gamma = compare_fp32_to_fp64(factorial_gamma)(x)
13+
assert_allclose(y_fori, y_gamma, atol=1e-2)

0 commit comments

Comments
 (0)