From 0e986f68e9ab46bb6b78dd989db8ea91a8a0a5bd Mon Sep 17 00:00:00 2001 From: Hatem Helal Date: Tue, 17 Oct 2023 10:01:03 +0100 Subject: [PATCH] Update basisset to use numpy over jax.numpy (#126) * Update basisset to use numpy over jax.numpy * change lmn to use signed int32 * use jnp in primitive product method to support jit --- pyscf_ipu/experimental/basis.py | 7 ++++--- pyscf_ipu/experimental/primitive.py | 13 +++++++------ pyscf_ipu/experimental/structure.py | 3 +++ 3 files changed, 14 insertions(+), 9 deletions(-) diff --git a/pyscf_ipu/experimental/basis.py b/pyscf_ipu/experimental/basis.py index eb7ebad..7ff73b6 100644 --- a/pyscf_ipu/experimental/basis.py +++ b/pyscf_ipu/experimental/basis.py @@ -3,6 +3,7 @@ import chex import jax.numpy as jnp +import numpy as np from .orbital import Orbital from .structure import Structure @@ -61,9 +62,9 @@ def basisset(structure: Structure, basis_name: str = "sto-3g"): for lmn in LMN_MAP[s["angular_momentum"][0]]: ao = Orbital.from_bse( center=center, - alphas=jnp.array(s["exponents"], dtype=float), - lmn=jnp.array(lmn, dtype=jnp.int32), - coefficients=jnp.array(s["coefficients"], dtype=float), + alphas=np.array(s["exponents"], dtype=np.float32), + lmn=np.array(lmn, dtype=np.int32), + coefficients=np.array(s["coefficients"], dtype=np.float32), ) orbitals.append(ao) diff --git a/pyscf_ipu/experimental/primitive.py b/pyscf_ipu/experimental/primitive.py index 283ac9e..1113b5a 100644 --- a/pyscf_ipu/experimental/primitive.py +++ b/pyscf_ipu/experimental/primitive.py @@ -3,16 +3,17 @@ import chex import jax.numpy as jnp -from jax.scipy.special import gammaln +import numpy as np +from scipy.special import gammaln from .types import Float3, FloatN, FloatNx3, Int3 @chex.dataclass class Primitive: - center: Float3 = jnp.zeros(3, dtype=jnp.float32) + center: Float3 = np.zeros(3, dtype=np.float32) alpha: float = 1.0 - lmn: Int3 = jnp.zeros(3, dtype=jnp.int32) + lmn: Int3 = np.zeros(3, dtype=np.int32) norm: Optional[float] = None def __post_init__(self): @@ -21,16 +22,16 @@ def __post_init__(self): @property def angular_momentum(self) -> int: - return jnp.sum(self.lmn) + return np.sum(self.lmn) def __call__(self, pos: FloatNx3) -> FloatN: return eval_primitive(self, pos) def normalize(lmn: Int3, alpha: float) -> float: - L = jnp.sum(lmn) + L = np.sum(lmn) N = ((1 / 2) / alpha) ** (L + 3 / 2) - N *= jnp.exp(jnp.sum(gammaln(lmn + 1 / 2))) + N *= np.exp(np.sum(gammaln(lmn + 1 / 2))) return N**-0.5 diff --git a/pyscf_ipu/experimental/structure.py b/pyscf_ipu/experimental/structure.py index 7977151..28c0316 100644 --- a/pyscf_ipu/experimental/structure.py +++ b/pyscf_ipu/experimental/structure.py @@ -20,6 +20,9 @@ def __post_init__(self): if not self.is_bohr: self.position = to_bohr(self.position) + # single atom case + self.position = np.atleast_2d(self.position) + @property def num_atoms(self) -> int: return len(self.atomic_number)