From ce026e9047cddc5e53e49c15b9514ae34707f60d Mon Sep 17 00:00:00 2001 From: Hatem Helal Date: Sun, 15 Oct 2023 15:54:09 +0000 Subject: [PATCH] use jnp in primitive product method to support jit --- pyscf_ipu/experimental/primitive.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyscf_ipu/experimental/primitive.py b/pyscf_ipu/experimental/primitive.py index 3451113..1113b5a 100644 --- a/pyscf_ipu/experimental/primitive.py +++ b/pyscf_ipu/experimental/primitive.py @@ -41,7 +41,7 @@ def product(a: Primitive, b: Primitive) -> Primitive: lmn = a.lmn + b.lmn c = a.norm * b.norm Rab = a.center - b.center - c *= np.exp(-a.alpha * b.alpha / alpha * np.inner(Rab, Rab)) + c *= jnp.exp(-a.alpha * b.alpha / alpha * jnp.inner(Rab, Rab)) return Primitive(center=center, alpha=alpha, lmn=lmn, norm=c)