diff --git a/mess/hamiltonian.py b/mess/hamiltonian.py index 5455b4e..716ab10 100644 --- a/mess/hamiltonian.py +++ b/mess/hamiltonian.py @@ -1,5 +1,6 @@ # Copyright (c) 2024 Graphcore Ltd. All rights reserved. -from typing import Tuple, Literal, get_args +from typing import Literal, Optional, Tuple, get_args + import equinox as eqx import jax import jax.numpy as jnp @@ -15,15 +16,14 @@ from mess.structure import nuclear_energy from mess.types import FloatNxN, OrthNormTransform from mess.xcfunctional import ( - lda_correlation_vwn, - lda_exchange, + gga_correlation_lyp, gga_correlation_pbe, - gga_exchange_pbe, gga_exchange_b88, - gga_correlation_lyp, + gga_exchange_pbe, + lda_correlation_vwn, + lda_exchange, ) - xcstr = Literal["lda", "pbe", "pbe0", "b3lyp", "hfx"] @@ -146,8 +146,13 @@ def __call__(self, P: FloatNxN) -> ScalarLike: def build_xcfunc( - xc_method: xcstr, basis: Basis, two_electron: TwoElectron + xc_method: xcstr, basis: Basis, two_electron: Optional[TwoElectron] = None ) -> eqx.Module: + if two_electron is None and xc_method in ("pbe0", "b3lyp"): + raise ValueError( + f"Hybrid functional {xc_method} requires providing TwoElectron integrals" + ) + match xc_method: case "lda": return LDA(basis)