Skip to content
This repository has been archived by the owner on Sep 24, 2024. It is now read-only.

Commit

Permalink
separate orthonormal transformations
Browse files Browse the repository at this point in the history
  • Loading branch information
hatemhelal committed Apr 29, 2024
1 parent 44759b1 commit 5eaf629
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 22 deletions.
4 changes: 2 additions & 2 deletions mess/hamiltonian.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from mess.integrals import eri_basis, kinetic_basis, nuclear_basis, overlap_basis
from mess.interop import to_pyscf
from mess.mesh import Mesh, density, density_and_grad, xcmesh_from_pyscf
from mess.scf import otransform_symmetric
from mess.orthnorm import canonical
from mess.structure import nuclear_energy
from mess.types import FloatNxN, OrthNormTransform
from mess.xcfunctional import (
Expand Down Expand Up @@ -182,7 +182,7 @@ class Hamiltonian(eqx.Module):
def __init__(
self,
basis: Basis,
ont: OrthNormTransform = otransform_symmetric,
ont: OrthNormTransform = canonical,
xc_method: xcstr = "lda",
):
super().__init__()
Expand Down
73 changes: 73 additions & 0 deletions mess/orthnorm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# Copyright (c) 2024 Graphcore Ltd. All rights reserved.
import jax.numpy as jnp
import jax.numpy.linalg as jnl

from mess.types import FloatNxN

"""Orthonormal transformation.
Evaluates the transformation matrix :math:`X` that satisfies
.. math:: \mathbf{X}^T \mathbf{S} \mathbf{X} = \mathbb{I}
where :math:`\mathbf{S}` is the overlap matrix of the non-orthonormal basis and
:math:`\mathbb{I}` is the identity matrix.
This module implements a few commonly used orthonormalisation transforms.
"""


def canonical(S: FloatNxN) -> FloatNxN:
"""Canonical orthonormal transformation
.. math:: \mathbf{X} = \mathbf{U} \mathbf{s}^{-1/2}
where :math:`\mathbf{U}` and :math:`\mathbf{s}` are the eigenvectors and
eigenvalues of the overlap matrix :math:`\mathbf{S}`.
Args:
S (FloatNxN): overlap matrix for the non-orthonormal basis.
Returns:
FloatNxN: canonical orthonormal transformation matrix
"""
s, U = jnl.eigh(S)
s = jnp.diag(jnp.power(s, -0.5))
return U @ s


def symmetric(S: FloatNxN) -> FloatNxN:
"""Symmetric orthonormal transformation
.. math:: \mathbf{X} = \mathbf{U} \mathbf{s}^{-1/2} \mathbf{U}^T
where :math:`\mathbf{U}` and :math:`\mathbf{s}` are the eigenvectors and
eigenvalues of the overlap matrix :math:`\mathbf{S}`.
Args:
S (FloatNxN): overlap matrix for the non-orthonormal basis.
Returns:
FloatNxN: symmetric orthonormal transformation matrix
"""
s, U = jnl.eigh(S)
s = jnp.diag(jnp.power(s, -0.5))
return U @ s @ U.T


def cholesky(S: FloatNxN) -> FloatNxN:
"""Cholesky orthonormal transformation
.. math:: \mathbf{X} = (\mathbf{L}^{-1})^T
where :math:`\mathbf{L}` is the lower triangular matrix the satisfies the Cholesky
decomposition of the overlap matrix :math:`\mathbf{S}`.
Args:
S (FloatNxN): overlap matrix for the non-orthonormal basis.
Returns:
FloatNxN: cholesky orthonormal transformation matrix
"""
L = jnl.cholesky(S)
return jnl.inv(L).T
23 changes: 3 additions & 20 deletions mess/scf.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,18 @@
# Copyright (c) 2024 Graphcore Ltd. All rights reserved.
from typing import Callable

import jax.numpy as jnp
import jax.numpy.linalg as jnl
from jax.lax import while_loop

from mess.basis import Basis
from mess.integrals import eri_basis, kinetic_basis, nuclear_basis, overlap_basis
from mess.structure import nuclear_energy


def otransform_canonical(S):
s, U = jnl.eigh(S)
s = jnp.diag(jnp.power(s, -0.5))
return U @ s


def otransform_symmetric(S):
s, U = jnl.eigh(S)
s = jnp.diag(jnp.power(s, -0.5))
return U @ s @ U.T


def otransform_cholesky(S):
L = jnl.cholesky(S)
return jnl.inv(L).T
from mess.orthnorm import cholesky
from mess.types import OrthNormTransform


def scf(
basis: Basis,
otransform: Callable = otransform_cholesky,
otransform: OrthNormTransform = cholesky,
max_iters: int = 32,
tolerance: float = 1e-4,
):
Expand Down

0 comments on commit 5eaf629

Please sign in to comment.