Skip to content

Commit

Permalink
Update basisset to use numpy over jax.numpy (#126)
Browse files Browse the repository at this point in the history
* Update basisset to use numpy over jax.numpy

* change lmn to use signed int32

* use jnp in primitive product method to support jit
  • Loading branch information
hatemhelal authored Oct 17, 2023
1 parent c21b92b commit 0e986f6
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 9 deletions.
7 changes: 4 additions & 3 deletions pyscf_ipu/experimental/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import chex
import jax.numpy as jnp
import numpy as np

from .orbital import Orbital
from .structure import Structure
Expand Down Expand Up @@ -61,9 +62,9 @@ def basisset(structure: Structure, basis_name: str = "sto-3g"):
for lmn in LMN_MAP[s["angular_momentum"][0]]:
ao = Orbital.from_bse(
center=center,
alphas=jnp.array(s["exponents"], dtype=float),
lmn=jnp.array(lmn, dtype=jnp.int32),
coefficients=jnp.array(s["coefficients"], dtype=float),
alphas=np.array(s["exponents"], dtype=np.float32),
lmn=np.array(lmn, dtype=np.int32),
coefficients=np.array(s["coefficients"], dtype=np.float32),
)
orbitals.append(ao)

Expand Down
13 changes: 7 additions & 6 deletions pyscf_ipu/experimental/primitive.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,17 @@

import chex
import jax.numpy as jnp
from jax.scipy.special import gammaln
import numpy as np
from scipy.special import gammaln

from .types import Float3, FloatN, FloatNx3, Int3


@chex.dataclass
class Primitive:
center: Float3 = jnp.zeros(3, dtype=jnp.float32)
center: Float3 = np.zeros(3, dtype=np.float32)
alpha: float = 1.0
lmn: Int3 = jnp.zeros(3, dtype=jnp.int32)
lmn: Int3 = np.zeros(3, dtype=np.int32)
norm: Optional[float] = None

def __post_init__(self):
Expand All @@ -21,16 +22,16 @@ def __post_init__(self):

@property
def angular_momentum(self) -> int:
return jnp.sum(self.lmn)
return np.sum(self.lmn)

def __call__(self, pos: FloatNx3) -> FloatN:
return eval_primitive(self, pos)


def normalize(lmn: Int3, alpha: float) -> float:
L = jnp.sum(lmn)
L = np.sum(lmn)
N = ((1 / 2) / alpha) ** (L + 3 / 2)
N *= jnp.exp(jnp.sum(gammaln(lmn + 1 / 2)))
N *= np.exp(np.sum(gammaln(lmn + 1 / 2)))
return N**-0.5


Expand Down
3 changes: 3 additions & 0 deletions pyscf_ipu/experimental/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ def __post_init__(self):
if not self.is_bohr:
self.position = to_bohr(self.position)

# single atom case
self.position = np.atleast_2d(self.position)

@property
def num_atoms(self) -> int:
return len(self.atomic_number)
Expand Down

0 comments on commit 0e986f6

Please sign in to comment.