From e350c44c25d65ef236c7205398eaeb69eb52ec0c Mon Sep 17 00:00:00 2001 From: Joachim Moeyens Date: Sun, 13 Aug 2023 08:44:54 -0700 Subject: [PATCH] Add mean_and_covariance_from_sigma_points --- adam_core/coordinates/covariances.py | 28 +++++++++++++++++++ .../coordinates/tests/test_covariances.py | 19 ++++++++----- 2 files changed, 40 insertions(+), 7 deletions(-) diff --git a/adam_core/coordinates/covariances.py b/adam_core/coordinates/covariances.py index 8c89e812..58d81b65 100644 --- a/adam_core/coordinates/covariances.py +++ b/adam_core/coordinates/covariances.py @@ -318,6 +318,34 @@ def sample_covariance_sigma_points( return sigma_points, W, W_cov +def mean_and_covariance_from_sigma_points(sigma_points, W, W_cov): + """ + Calculate a covariance matrix from sigma points and their corresponding weights. + + Parameters + ---------- + sigma_points : `~numpy.ndarray` (2 * D + 1, D) + Sigma points drawn from the distribution. + W: `~numpy.ndarray` (2 * D + 1) + Weights of the sigma points. + W_cov: `~numpy.ndarray` (2 * D + 1) + Weights of the sigma points to reconstruct covariance matrix. + + Returns + ------- + mean : `~numpy.ndarray` (D) + Mean calculated from the sigma points and weights. + cov : `~numpy.ndarray` (D, D) + Covariance matrix calculated from the sigma points and weights. + """ + # Calculate the mean from the sigma points and weights + mean = np.dot(W, sigma_points) + + # Calculate the covariance matrix from the sigma points and weights + cov = np.cov(sigma_points, aweights=W_cov, rowvar=False, bias=True) + return mean, cov + + def transform_covariances_sampling( coords: np.ndarray, covariances: np.ndarray, diff --git a/adam_core/coordinates/tests/test_covariances.py b/adam_core/coordinates/tests/test_covariances.py index 455a9ac5..140232d0 100644 --- a/adam_core/coordinates/tests/test_covariances.py +++ b/adam_core/coordinates/tests/test_covariances.py @@ -1,10 +1,14 @@ import numpy as np from ...utils.helpers.orbits import make_real_orbits -from ..covariances import CoordinateCovariances, sample_covariance_sigma_points +from ..covariances import ( + CoordinateCovariances, + mean_and_covariance_from_sigma_points, + sample_covariance_sigma_points, +) -def test_sample_covariance_sigma_points(): +def test_sigma_points(): # Get a sample of real orbits and test that sigma point sampling # allows the state vector and its covariance to be reconstructed orbits = make_real_orbits() @@ -25,12 +29,13 @@ def test_sample_covariance_sigma_points(): # since beta = 0 internally assert W_cov[0] == 0.0 - # Reconstruct the mean and covariance - mean_sg = np.dot(W, samples) + # Reconstruct the mean and covariance and test that they match + # the original inputs to within 1e-14 + mean_sg, covariance_sg = mean_and_covariance_from_sigma_points( + samples, W, W_cov + ) np.testing.assert_allclose(mean_sg, mean, rtol=0, atol=1e-14) - - cov_sg = np.cov(samples, aweights=W_cov, rowvar=False, bias=True) - np.testing.assert_allclose(cov_sg, covariance, rtol=0, atol=1e-14) + np.testing.assert_allclose(covariance_sg, covariance, rtol=0, atol=1e-14) def test_CoordinateCovariances_from_sigmas():