Skip to content

Commit

Permalink
Add VariantOrbits.collapse -- collapses variants orbits at different …
Browse files Browse the repository at this point in the history
…epochs and recomputes the covariance matrix
  • Loading branch information
moeyensj committed Aug 18, 2023
1 parent 5d04641 commit ceaa7c2
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 0 deletions.
36 changes: 36 additions & 0 deletions adam_core/orbits/tests/test_variants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import numpy as np

from ...utils.helpers.orbits import make_real_orbits
from ..variants import VariantOrbits


def test_VariantOrbits():

# Get a sample of real orbits
orbits = make_real_orbits(10)

# Create a variant orbits object (expands the covariance matrix)
# around the mean state
variant_orbits = VariantOrbits.create(orbits)

# For these 10 orbits this will select sigma-points so lets
# check that the number of sigma-points is correct
assert len(variant_orbits) == len(orbits) * 13

# Now lets collapse the sigma-points back and see if we can reconstruct
# the input covairance matrix
collapsed_orbits = variant_orbits.collapse(orbits)

# Check that the covariance matrices are close
np.testing.assert_allclose(
collapsed_orbits.coordinates.covariance.to_matrix(),
orbits.coordinates.covariance.to_matrix(),
rtol=0,
atol=1e-14,
)

# Check that the orbit ids are the same
np.testing.assert_equal(
collapsed_orbits.orbit_id.to_numpy(zero_copy_only=False),
orbits.orbit_id.to_numpy(zero_copy_only=False),
)
80 changes: 80 additions & 0 deletions adam_core/orbits/variants.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
import uuid
from typing import Literal

import numpy as np
import pyarrow as pa
import pyarrow.compute as pc
import quivr as qv

from ..coordinates.cartesian import CartesianCoordinates
from ..coordinates.covariances import (
CoordinateCovariances,
mean_and_covariance_from_weighted_samples,
)
from ..coordinates.variants import create_coordinate_variants
from .orbits import Orbits

Expand Down Expand Up @@ -95,3 +101,77 @@ def link_to_orbits(
"jd2": self.coordinates.time.jd2,
},
)

def collapse(self, orbits: Orbits) -> Orbits:
"""
Collapse the variants and recalculate the covariance matrix for each
each orbit at each epoch. The mean state is taken from the orbits class and
is not calculate from the variants.
Parameters
----------
orbits : `~adam_core.orbits.orbits.Orbits`
Orbits from which the variants were generated.
Returns
-------
collapsed_orbits : `~adam_core.orbits.orbits.Orbits`
The collapsed orbits.
"""
link = self.link_to_orbits(orbits)

# Iterate over the variants and calculate the mean state and covariance matrix
# for each orbit at each epoch then create a new orbit with the calculated covariance matrix
orbits_list = []
for key, orbit, variants in link.iterate():
key = key.as_py()

assert len(orbit) == 1

samples = variants.coordinates.values
mean_state, covariance = mean_and_covariance_from_weighted_samples(
samples, variants.weights.to_numpy(), variants.weights_cov.to_numpy()
)

orbit_collapsed = Orbits.from_kwargs(
orbit_id=orbit.orbit_id,
object_id=orbit.object_id,
coordinates=CartesianCoordinates.from_kwargs(
time=orbit.coordinates.time,
x=orbit.coordinates.x,
y=orbit.coordinates.y,
z=orbit.coordinates.z,
vx=orbit.coordinates.vx,
vy=orbit.coordinates.vy,
vz=orbit.coordinates.vz,
covariance=CoordinateCovariances.from_matrix(covariance),
frame=orbit.coordinates.frame,
origin=orbit.coordinates.origin,
),
)

orbits_list.append(orbit_collapsed)

orbits_collapsed = qv.concatenate(orbits_list)

# Array of indices into the collapsed orbits
orbits_idx = pa.array(np.arange(0, len(orbits_collapsed)))

# Make a list of arrays that will be used to sort the orbits
orbits_idx_sorted_list = []

# Loop over input orbits and figure out where in the collapsed orbits they occur
# There has to be an easier or better way to do this?
for orbit in orbits:
mask_orbit_id = pc.equal(orbits_collapsed.orbit_id, orbit.orbit_id[0])
mask_jd1 = pc.equal(
orbits_collapsed.coordinates.time.jd1, orbit.coordinates.time.jd1[0]
)
mask_jd2 = pc.equal(
orbits_collapsed.coordinates.time.jd2, orbit.coordinates.time.jd2[0]
)
mask = pc.and_(mask_orbit_id, pc.and_(mask_jd1, mask_jd2))
orbits_idx_sorted_list.append(orbits_idx.filter(mask))

orbits_idx_sorted = pa.concat_arrays(orbits_idx_sorted_list)
return orbits_collapsed.take(orbits_idx_sorted)

0 comments on commit ceaa7c2

Please sign in to comment.