diff --git a/mess/hamiltonian.py b/mess/hamiltonian.py index 523dfb7..721a1b3 100644 --- a/mess/hamiltonian.py +++ b/mess/hamiltonian.py @@ -66,7 +66,7 @@ def exchange(self, P: FloatNxN) -> FloatNxN: class HartreeFockExchange(eqx.Module): two_electron: TwoElectron - def __init__(self, two_electron): + def __init__(self, two_electron: TwoElectron): self.two_electron = two_electron def __call__(self, P: FloatNxN) -> float: @@ -109,7 +109,7 @@ class PBE0(eqx.Module): mesh: Mesh hfx: HartreeFockExchange - def __init__(self, basis: Basis, two_electron): + def __init__(self, basis: Basis, two_electron: TwoElectron): self.basis = basis self.mesh = xcmesh_from_pyscf(basis.structure) self.hfx = HartreeFockExchange(two_electron) @@ -126,7 +126,7 @@ class B3LYP(eqx.Module): mesh: Mesh hfx: HartreeFockExchange - def __init__(self, basis: Basis, two_electron): + def __init__(self, basis: Basis, two_electron: TwoElectron): self.basis = basis self.mesh = xcmesh_from_pyscf(basis.structure) self.hfx = HartreeFockExchange(two_electron) @@ -141,7 +141,7 @@ def __call__(self, P: FloatNxN) -> float: return E_xc + 0.2 * self.hfx(P) -def build_xcfunc(xc_method, basis, two_electron): +def build_xcfunc(xc_method: str, basis: Basis, two_electron: TwoElectron) -> eqx.Module: if xc_method == "lda": return LDA(basis) if xc_method == "pbe":