diff --git a/mess/orthnorm.py b/mess/orthnorm.py index 07d5e0e..f813edd 100644 --- a/mess/orthnorm.py +++ b/mess/orthnorm.py @@ -1,9 +1,4 @@ # Copyright (c) 2024 Graphcore Ltd. All rights reserved. -import jax.numpy as jnp -import jax.numpy.linalg as jnl - -from mess.types import FloatNxN - """Orthonormal transformation. Evaluates the transformation matrix :math:`X` that satisfies @@ -16,6 +11,11 @@ This module implements a few commonly used orthonormalisation transforms. """ +import jax.numpy as jnp +import jax.numpy.linalg as jnl + +from mess.types import FloatNxN + def canonical(S: FloatNxN) -> FloatNxN: """Canonical orthonormal transformation @@ -60,7 +60,7 @@ def cholesky(S: FloatNxN) -> FloatNxN: .. math:: \mathbf{X} = (\mathbf{L}^{-1})^T - where :math:`\mathbf{L}` is the lower triangular matrix the satisfies the Cholesky + where :math:`\mathbf{L}` is the lower triangular matrix that satisfies the Cholesky decomposition of the overlap matrix :math:`\mathbf{S}`. Args: