Skip to content

Commit

Permalink
Added faster method for inverting symmetric matrices
Browse files Browse the repository at this point in the history
Fixes #55.
The implementation was taken from here:
https://stackoverflow.com/a/58719188
with some minor improvements.
It also raises a `LinAlgError` if the matrix is singular.
  • Loading branch information
JCGoran committed Apr 10, 2023
1 parent b3912c8 commit f011dec
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 3 deletions.
13 changes: 11 additions & 2 deletions fitk/derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,13 @@

# first party imports
from fitk.tensors import FisherMatrix
from fitk.utilities import P, ValidationError, find_diff_weights, is_iterable
from fitk.utilities import (
P,
ValidationError,
fast_positive_definite_inverse,
find_diff_weights,
is_iterable,
)


def _zero_out(array, threshold: float):
Expand Down Expand Up @@ -570,6 +576,9 @@ def fisher_matrix(
ValueError
if the argument `external_covariance` is not a square matrix
LinAlgError
if the covariance matrix is singular
Notes
-----
The `order` parameter is ignored if passed to `D`.
Expand Down Expand Up @@ -636,7 +645,7 @@ def fisher_matrix(
*zip(names, fiducials), **kwargs_covariance
)

inverse_covariance_matrix = np.linalg.inv(covariance_matrix)
inverse_covariance_matrix = fast_positive_definite_inverse(covariance_matrix)

covariance_shape = np.shape(inverse_covariance_matrix)

Expand Down
8 changes: 7 additions & 1 deletion fitk/tensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
MismatchingValuesError,
P,
ParameterNotFoundError,
fast_positive_definite_inverse,
get_index_of_other_array,
is_iterable,
is_positive_semidefinite,
Expand Down Expand Up @@ -1413,6 +1414,11 @@ def inverse(self):
array_like : float
the covariance matrix as a numpy array
Raises
------
LinAlgError
if the Fisher matrix is singular
Examples
--------
>>> fm = FisherMatrix(np.diag([1, 2, 5]))
Expand All @@ -1421,7 +1427,7 @@ def inverse(self):
[0. , 0.5, 0. ],
[0. , 0. , 0.2]])
"""
return np.linalg.inv(self.values)
return fast_positive_definite_inverse(self.values)

def determinant(self):
"""
Expand Down
28 changes: 28 additions & 0 deletions fitk/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@
import math
from collections.abc import Collection, Sequence
from dataclasses import dataclass
from functools import lru_cache
from math import factorial
from typing import Optional, Union

# third party imports
import numpy as np
from scipy.linalg.lapack import dpotrf, dpotri # pylint: disable=no-name-in-module


@dataclass
Expand Down Expand Up @@ -396,3 +398,29 @@ def math_mode(
raise TypeError(err) from err

return [f"${_}$" for _ in arg]


@lru_cache(maxsize=None)
def _get_inds(size: int):
return np.tri(size, k=-1, dtype=bool)


def upper_triangular_to_symmetric(upper_triangular):
r"""Convert an upper triangular matrix to a full, symmetric matrix."""
size = upper_triangular.shape[0]
inds = _get_inds(size)
upper_triangular[inds] = upper_triangular.T[inds]


def fast_positive_definite_inverse(matrix):
r"""Compute the fast inverse of a positive-definite symmetric NxN matrix."""
cholesky, info = dpotrf(matrix)
if info != 0:
raise np.linalg.LinAlgError(f"dpotrf failed on input {matrix}")

inv, info = dpotri(cholesky)
if info != 0:
raise np.linalg.LinAlgError(f"dpotri failed on input {cholesky}")

upper_triangular_to_symmetric(inv)
return inv

0 comments on commit f011dec

Please sign in to comment.