Skip to content
This repository has been archived by the owner on Sep 24, 2024. It is now read-only.

Commit

Permalink
add orthonormalise method to hamiltonian (fix jit failure)
Browse files Browse the repository at this point in the history
  • Loading branch information
hatemhelal committed Apr 22, 2024
1 parent 094493f commit a2a0fae
Showing 1 changed file with 11 additions and 4 deletions.
15 changes: 11 additions & 4 deletions mess/hamiltonian.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ def build_xcfunc(xc_method: str, basis: Basis, two_electron: TwoElectron) -> eqx


class Hamiltonian(eqx.Module):
X: FloatNxN
H_core: FloatNxN
basis: Basis
two_electron: TwoElectron
Expand All @@ -165,10 +166,13 @@ class Hamiltonian(eqx.Module):
def __init__(
self,
basis: Basis,
ont: OrthNormTransform = otransform_symmetric,
xc_method: str = "lda",
):
super().__init__()
self.basis = basis
S = overlap_basis(basis)
self.X = ont(S)
self.H_core = kinetic_basis(basis) + nuclear_basis(basis).sum(axis=0)
self.two_electron = TwoElectron(basis, backend="pyscf")
self.xcfunc = build_xcfunc(xc_method, basis, self.two_electron)
Expand All @@ -181,19 +185,22 @@ def __call__(self, P: FloatNxN) -> float:
E = E_core + E_xc + E_es
return E

def orthonormalise(self, Z: FloatNxN) -> FloatNxN:
C = self.X @ jnl.qr(Z).Q
return C


@jax.jit
def minimise(H: Hamiltonian, ont: OrthNormTransform = otransform_symmetric):
def minimise(H: Hamiltonian):
def f(Z, _):
C = X @ jnl.qr(Z).Q
C = H.orthonormalise(Z)
P = H.basis.density_matrix(C)
return H(P)

X = ont(overlap_basis(H.basis))
solver = optx.BFGS(rtol=1e-5, atol=1e-6)
Z = jnp.eye(H.basis.num_orbitals)
sol = optx.minimise(f, solver, Z)
C = X @ jnl.qr(sol.value).Q
C = H.orthonormalise(sol.value)
P = H.basis.density_matrix(C)
E_elec = H(P)
E_total = E_elec + nuclear_energy(H.basis.structure)
Expand Down

0 comments on commit a2a0fae

Please sign in to comment.