Skip to content

Commit

Permalink
Add mean_and_covariance_from_sigma_points
Browse files Browse the repository at this point in the history
  • Loading branch information
moeyensj committed Aug 13, 2023
1 parent aa31d1f commit e350c44
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 7 deletions.
28 changes: 28 additions & 0 deletions adam_core/coordinates/covariances.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,34 @@ def sample_covariance_sigma_points(
return sigma_points, W, W_cov


def mean_and_covariance_from_sigma_points(sigma_points, W, W_cov):
"""
Calculate a covariance matrix from sigma points and their corresponding weights.
Parameters
----------
sigma_points : `~numpy.ndarray` (2 * D + 1, D)
Sigma points drawn from the distribution.
W: `~numpy.ndarray` (2 * D + 1)
Weights of the sigma points.
W_cov: `~numpy.ndarray` (2 * D + 1)
Weights of the sigma points to reconstruct covariance matrix.
Returns
-------
mean : `~numpy.ndarray` (D)
Mean calculated from the sigma points and weights.
cov : `~numpy.ndarray` (D, D)
Covariance matrix calculated from the sigma points and weights.
"""
# Calculate the mean from the sigma points and weights
mean = np.dot(W, sigma_points)

# Calculate the covariance matrix from the sigma points and weights
cov = np.cov(sigma_points, aweights=W_cov, rowvar=False, bias=True)
return mean, cov


def transform_covariances_sampling(
coords: np.ndarray,
covariances: np.ndarray,
Expand Down
19 changes: 12 additions & 7 deletions adam_core/coordinates/tests/test_covariances.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import numpy as np

from ...utils.helpers.orbits import make_real_orbits
from ..covariances import CoordinateCovariances, sample_covariance_sigma_points
from ..covariances import (
CoordinateCovariances,
mean_and_covariance_from_sigma_points,
sample_covariance_sigma_points,
)


def test_sample_covariance_sigma_points():
def test_sigma_points():
# Get a sample of real orbits and test that sigma point sampling
# allows the state vector and its covariance to be reconstructed
orbits = make_real_orbits()
Expand All @@ -25,12 +29,13 @@ def test_sample_covariance_sigma_points():
# since beta = 0 internally
assert W_cov[0] == 0.0

# Reconstruct the mean and covariance
mean_sg = np.dot(W, samples)
# Reconstruct the mean and covariance and test that they match
# the original inputs to within 1e-14
mean_sg, covariance_sg = mean_and_covariance_from_sigma_points(
samples, W, W_cov
)
np.testing.assert_allclose(mean_sg, mean, rtol=0, atol=1e-14)

cov_sg = np.cov(samples, aweights=W_cov, rowvar=False, bias=True)
np.testing.assert_allclose(cov_sg, covariance, rtol=0, atol=1e-14)
np.testing.assert_allclose(covariance_sg, covariance, rtol=0, atol=1e-14)


def test_CoordinateCovariances_from_sigmas():
Expand Down

0 comments on commit e350c44

Please sign in to comment.