Skip to content

Commit

Permalink
move enable_x64 into decorator
Browse files Browse the repository at this point in the history
  • Loading branch information
hatemhelal committed Sep 27, 2023
1 parent 360cdd4 commit ce02159
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 11 deletions.
10 changes: 6 additions & 4 deletions pyscf_ipu/experimental/numerics.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import jax.numpy as jnp
import numpy as np
from jax.experimental import enable_x64
from jaxtyping import Array


Expand All @@ -27,10 +28,11 @@ def wrapper(*args, **kwargs):
def compare_fp32_to_fp64(func: Callable):
@wraps(func)
def wrapper(*args, **kwargs):
outputs_fp32 = fpcast(func, dtype=jnp.float32)(*args, **kwargs)
outputs_fp64 = fpcast(func, dtype=jnp.float64)(*args, **kwargs)
print_compare(func.__name__, outputs_fp32, outputs_fp64)
return outputs_fp32
with enable_x64():
outputs_fp32 = fpcast(func, dtype=jnp.float32)(*args, **kwargs)
outputs_fp64 = fpcast(func, dtype=jnp.float64)(*args, **kwargs)
print_compare(func.__name__, outputs_fp32, outputs_fp64)
return outputs_fp32

return wrapper

Expand Down
12 changes: 5 additions & 7 deletions test/test_numerics.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
import jax.numpy as jnp
from jax.experimental import enable_x64
from numpy.testing import assert_allclose

from pyscf_ipu.experimental.integrals import factorial_fori, factorial_gamma
from pyscf_ipu.experimental.numerics import compare_fp32_to_fp64


def test_factorial():
with enable_x64():
n = 16
x = jnp.arange(n, dtype=jnp.float32)
y_fori = compare_fp32_to_fp64(factorial_fori)(x, n)
y_gamma = compare_fp32_to_fp64(factorial_gamma)(x)
assert_allclose(y_fori, y_gamma, 1e-2)
n = 16
x = jnp.arange(n, dtype=jnp.float32)
y_fori = compare_fp32_to_fp64(factorial_fori)(x, n)
y_gamma = compare_fp32_to_fp64(factorial_gamma)(x)
assert_allclose(y_fori, y_gamma, 1e-2)

0 comments on commit ce02159

Please sign in to comment.