From b4f37443688dc63832e185c5247301e1a325a746 Mon Sep 17 00:00:00 2001 From: Hatem Helal Date: Tue, 23 Apr 2024 15:17:11 +0000 Subject: [PATCH] switch to scalarlike --- mess/hamiltonian.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/mess/hamiltonian.py b/mess/hamiltonian.py index 9c20f35..f7affc3 100644 --- a/mess/hamiltonian.py +++ b/mess/hamiltonian.py @@ -4,7 +4,7 @@ import jax.numpy as jnp import jax.numpy.linalg as jnl import optimistix as optx -from jaxtyping import Array +from jaxtyping import Array, ScalarLike from mess.basis import Basis from mess.integrals import eri_basis, kinetic_basis, nuclear_basis, overlap_basis @@ -69,7 +69,7 @@ class HartreeFockExchange(eqx.Module): def __init__(self, two_electron: TwoElectron): self.two_electron = two_electron - def __call__(self, P: FloatNxN) -> float: + def __call__(self, P: FloatNxN) -> ScalarLike: K = self.two_electron.exchange(P) return -0.25 * jnp.sum(P * K) @@ -82,7 +82,7 @@ def __init__(self, basis: Basis): self.basis = basis self.mesh = xcmesh_from_pyscf(basis.structure) - def __call__(self, P: FloatNxN) -> float: + def __call__(self, P: FloatNxN) -> ScalarLike: rho = density(self.basis, self.mesh, P) eps_xc = lda_exchange(rho) + lda_correlation_vwn(rho) E_xc = jnp.einsum("i,i,i", self.mesh.weights, rho, eps_xc) @@ -97,7 +97,7 @@ def __init__(self, basis: Basis): self.basis = basis self.mesh = xcmesh_from_pyscf(basis.structure) - def __call__(self, P: FloatNxN) -> float: + def __call__(self, P: FloatNxN) -> ScalarLike: rho, grad_rho = density_and_grad(self.basis, self.mesh, P) eps_xc = gga_exchange_pbe(rho, grad_rho) + gga_correlation_pbe(rho, grad_rho) E_xc = jnp.einsum("i,i,i", self.mesh.weights, rho, eps_xc) @@ -114,7 +114,7 @@ def __init__(self, basis: Basis, two_electron: TwoElectron): self.mesh = xcmesh_from_pyscf(basis.structure) self.hfx = HartreeFockExchange(two_electron) - def __call__(self, P: FloatNxN) -> float: + def __call__(self, P: FloatNxN) -> ScalarLike: rho, grad_rho = density_and_grad(self.basis, self.mesh, P) e = 0.75 * gga_exchange_pbe(rho, grad_rho) + gga_correlation_pbe(rho, grad_rho) E_xc = jnp.einsum("i,i,i", self.mesh.weights, rho, e) @@ -131,7 +131,7 @@ def __init__(self, basis: Basis, two_electron: TwoElectron): self.mesh = xcmesh_from_pyscf(basis.structure) self.hfx = HartreeFockExchange(two_electron) - def __call__(self, P: FloatNxN) -> float: + def __call__(self, P: FloatNxN) -> ScalarLike: rho, grad_rho = density_and_grad(self.basis, self.mesh, P) eps_x = 0.08 * lda_exchange(rho) + 0.72 * gga_exchange_b88(rho, grad_rho) vwn_c = (1 - 0.81) * lda_correlation_vwn(rho) @@ -177,7 +177,7 @@ def __init__( self.two_electron = TwoElectron(basis, backend="pyscf") self.xcfunc = build_xcfunc(xc_method, basis, self.two_electron) - def __call__(self, P: FloatNxN) -> float: + def __call__(self, P: FloatNxN) -> ScalarLike: E_core = jnp.sum(self.H_core * P) E_xc = self.xcfunc(P) J = self.two_electron.coloumb(P)