Skip to content
This repository has been archived by the owner on Sep 24, 2024. It is now read-only.

Commit

Permalink
switch to scalarlike
Browse files Browse the repository at this point in the history
  • Loading branch information
hatemhelal committed Apr 23, 2024
1 parent a2a0fae commit b4f3744
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions mess/hamiltonian.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import jax.numpy as jnp
import jax.numpy.linalg as jnl
import optimistix as optx
from jaxtyping import Array
from jaxtyping import Array, ScalarLike

from mess.basis import Basis
from mess.integrals import eri_basis, kinetic_basis, nuclear_basis, overlap_basis
Expand Down Expand Up @@ -69,7 +69,7 @@ class HartreeFockExchange(eqx.Module):
def __init__(self, two_electron: TwoElectron):
self.two_electron = two_electron

def __call__(self, P: FloatNxN) -> float:
def __call__(self, P: FloatNxN) -> ScalarLike:
K = self.two_electron.exchange(P)
return -0.25 * jnp.sum(P * K)

Expand All @@ -82,7 +82,7 @@ def __init__(self, basis: Basis):
self.basis = basis
self.mesh = xcmesh_from_pyscf(basis.structure)

def __call__(self, P: FloatNxN) -> float:
def __call__(self, P: FloatNxN) -> ScalarLike:
rho = density(self.basis, self.mesh, P)
eps_xc = lda_exchange(rho) + lda_correlation_vwn(rho)
E_xc = jnp.einsum("i,i,i", self.mesh.weights, rho, eps_xc)
Expand All @@ -97,7 +97,7 @@ def __init__(self, basis: Basis):
self.basis = basis
self.mesh = xcmesh_from_pyscf(basis.structure)

def __call__(self, P: FloatNxN) -> float:
def __call__(self, P: FloatNxN) -> ScalarLike:
rho, grad_rho = density_and_grad(self.basis, self.mesh, P)
eps_xc = gga_exchange_pbe(rho, grad_rho) + gga_correlation_pbe(rho, grad_rho)
E_xc = jnp.einsum("i,i,i", self.mesh.weights, rho, eps_xc)
Expand All @@ -114,7 +114,7 @@ def __init__(self, basis: Basis, two_electron: TwoElectron):
self.mesh = xcmesh_from_pyscf(basis.structure)
self.hfx = HartreeFockExchange(two_electron)

def __call__(self, P: FloatNxN) -> float:
def __call__(self, P: FloatNxN) -> ScalarLike:
rho, grad_rho = density_and_grad(self.basis, self.mesh, P)
e = 0.75 * gga_exchange_pbe(rho, grad_rho) + gga_correlation_pbe(rho, grad_rho)
E_xc = jnp.einsum("i,i,i", self.mesh.weights, rho, e)
Expand All @@ -131,7 +131,7 @@ def __init__(self, basis: Basis, two_electron: TwoElectron):
self.mesh = xcmesh_from_pyscf(basis.structure)
self.hfx = HartreeFockExchange(two_electron)

def __call__(self, P: FloatNxN) -> float:
def __call__(self, P: FloatNxN) -> ScalarLike:
rho, grad_rho = density_and_grad(self.basis, self.mesh, P)
eps_x = 0.08 * lda_exchange(rho) + 0.72 * gga_exchange_b88(rho, grad_rho)
vwn_c = (1 - 0.81) * lda_correlation_vwn(rho)
Expand Down Expand Up @@ -177,7 +177,7 @@ def __init__(
self.two_electron = TwoElectron(basis, backend="pyscf")
self.xcfunc = build_xcfunc(xc_method, basis, self.two_electron)

def __call__(self, P: FloatNxN) -> float:
def __call__(self, P: FloatNxN) -> ScalarLike:
E_core = jnp.sum(self.H_core * P)
E_xc = self.xcfunc(P)
J = self.two_electron.coloumb(P)
Expand Down

0 comments on commit b4f3744

Please sign in to comment.