diff --git a/mess/hamiltonian.py b/mess/hamiltonian.py index ae644b5..7be73f5 100644 --- a/mess/hamiltonian.py +++ b/mess/hamiltonian.py @@ -12,7 +12,7 @@ from mess.integrals import eri_basis, kinetic_basis, nuclear_basis, overlap_basis from mess.interop import to_pyscf from mess.mesh import Mesh, density, density_and_grad, xcmesh_from_pyscf -from mess.orthnorm import canonical +from mess.orthnorm import symmetric from mess.structure import nuclear_energy from mess.types import FloatNxN, OrthNormTransform from mess.xcfunctional import ( @@ -25,6 +25,40 @@ ) xcstr = Literal["lda", "pbe", "pbe0", "b3lyp", "hfx"] +IntegralBackend = Literal["mess", "pyscf_cart", "pyscf_sph"] + + +class OneElectron(eqx.Module): + overlap: FloatNxN + kinetic: FloatNxN + nuclear: FloatNxN + + def __init__(self, basis: Basis, backend: IntegralBackend = "mess"): + """_summary_ + + Args: + basis (Basis): _description_ + backend (IntegralBackend, optional): _description_. Defaults to "mess". + + Raises: + ValueError: _description_ + ValueError: _description_ + + Returns: + _type_: _description_ + """ + if backend == "mess": + self.overlap = overlap_basis(basis) + self.kinetic = kinetic_basis(basis) + self.nuclear = nuclear_basis(basis).sum(axis=0) + elif backend.startswith("pyscf_"): + mol = to_pyscf(basis.structure, basis.basis_name) + kind = backend.split("_")[1] + S = jnp.array(mol.intor(f"int1e_ovlp_{kind}")) + N = 1 / jnp.sqrt(jnp.diagonal(S)) + self.overlap = N[:, jnp.newaxis] * N[jnp.newaxis, :] * S + self.kinetic = jnp.array(mol.intor(f"int1e_kin_{kind}")) + self.nuclear = jnp.array(mol.intor(f"int1e_nuc_{kind}")) class TwoElectron(eqx.Module): @@ -40,9 +74,10 @@ def __init__(self, basis: Basis, backend: str = "mess"): super().__init__() if backend == "mess": self.eri = eri_basis(basis) - elif backend == "pyscf": + elif backend.startswith("pyscf_"): mol = to_pyscf(basis.structure, basis.basis_name) - self.eri = jnp.array(mol.intor("int2e_cart", aosym="s1")) + kind = backend.split("_")[1] + self.eri = jnp.array(mol.intor(f"int2e_{kind}", aosym="s1")) def coloumb(self, P: FloatNxN) -> FloatNxN: """Build the Coloumb matrix (classical electrostatic) from the density matrix. @@ -182,15 +217,17 @@ class Hamiltonian(eqx.Module): def __init__( self, basis: Basis, - ont: OrthNormTransform = canonical, + ont: OrthNormTransform = symmetric, xc_method: xcstr = "lda", + backend: IntegralBackend = "pyscf_cart", ): super().__init__() self.basis = basis - S = overlap_basis(basis) + one_elec = OneElectron(basis, backend=backend) + S = one_elec.overlap self.X = ont(S) - self.H_core = kinetic_basis(basis) + nuclear_basis(basis).sum(axis=0) - self.two_electron = TwoElectron(basis, backend="pyscf") + self.H_core = one_elec.kinetic + one_elec.nuclear + self.two_electron = TwoElectron(basis, backend=backend) self.xcfunc = build_xcfunc(xc_method, basis, self.two_electron) def __call__(self, P: FloatNxN) -> ScalarLike: