From 45709f27afb78ca061eeaaa76528d886cbd7f866 Mon Sep 17 00:00:00 2001 From: Joachim Moeyens Date: Tue, 22 Aug 2023 15:35:57 -0700 Subject: [PATCH] Define CoordinateVariants table inside create_coordinate_variants Motivated by @spenczar: "Or if you're insane, you could define the class inside the function. I think this really would work!" (https://github.com/B612-Asteroid-Institute/adam_core/pull/41#discussion_r1296516532) Co-authored-by: Spencer Nelson --- adam_core/coordinates/variants.py | 146 ++++++++---------------------- adam_core/orbits/variants.py | 12 +-- 2 files changed, 45 insertions(+), 113 deletions(-) diff --git a/adam_core/coordinates/variants.py b/adam_core/coordinates/variants.py index 8f70c63b..d52481e9 100644 --- a/adam_core/coordinates/variants.py +++ b/adam_core/coordinates/variants.py @@ -1,4 +1,4 @@ -from typing import Literal, Tuple, Union +from typing import Literal, Union import numpy as np import quivr as qv @@ -29,7 +29,7 @@ def create_coordinate_variants( alpha: float = 1, beta: float = 0, kappa: float = 0, -) -> Tuple[np.ndarray, np.ndarray, np.ndarray, CoordinateType]: +) -> "CoordinateVariants": """ Sample and create variants for the given coordinates by sampling the covariance matrices. There are three supported methods: @@ -69,17 +69,8 @@ def create_coordinate_variants( Returns ------- - idx : '~numpy.ndarray' - The index of the coordinate that each sample belongs to. - weights : '~numpy.ndarray' - Weights of each sample. - cov_weights : '~numpy.ndarray' - Weights of the samples to reconstruct covariance matrix. - samples : {'~adam_core.coordinates.cartesian.CartesianCoordinates', - '~adam_core.coordinates.keplerian.KeplerianCoordinates', - '~adam_core.coordinates.cometary.CometaryCoordinates', - '~adam_core.coordinates.spherical.SphericalCoordinates'} - The samples drawn from the coordinate covariance matrices. + variants : '~adam_core.coordinates.variants.VariantCoordinates' + The variant coordinates. Raises ------ @@ -87,18 +78,29 @@ def create_coordinate_variants( If the covariance matrices are all undefined. If the input coordinates are not supported. """ - idx_list = [] - samples_list = [] - weights_list = [] - cov_weights_list = [] - origins_list = [] - times_list = [] - if coordinates.covariance.is_all_nan(): raise ValueError( "Cannot sample coordinate covariances when covariances are all undefined." ) + class CoordinateVariants(qv.Table): + index = qv.Int64Column() + sample = coordinates.as_column() + weight = qv.Float64Column() + weight_cov = qv.Float64Column() + + if isinstance(coordinates, CartesianCoordinates): + dimensions = ["x", "y", "z", "vx", "vy", "vz"] + elif isinstance(coordinates, SphericalCoordinates): + dimensions = ["rho", "lon", "lat", "vrho", "vlon", "vlat"] + elif isinstance(coordinates, KeplerianCoordinates): + dimensions = ["a", "e", "i", "raan", "ap", "M"] + elif isinstance(coordinates, CometaryCoordinates): + dimensions = ["q", "e", "i", "raan", "ap", "tp"] + else: + raise ValueError(f"Unsupported coordinate type: {type(coordinates)}") + + variants_list = [] for i, coordinate_i in enumerate(coordinates): mean = coordinate_i.values[0] @@ -146,92 +148,22 @@ def create_coordinate_variants( else: raise ValueError(f"Unknown coordinate covariance sampling method: {method}") - origins_list += [coordinate_i.origin for i in range(len(samples))] - times_list += [coordinate_i.time for i in range(len(samples))] - samples_list.append(samples) - weights_list.append(W) - cov_weights_list.append(W_cov) - idx_list.append(np.full(len(samples), i)) - - samples = np.concatenate(samples_list) - idx = np.concatenate(idx_list) - weights = np.concatenate(weights_list) - cov_weights = np.concatenate(cov_weights_list) - origins = qv.concatenate(origins_list) - times = qv.concatenate(times_list) - - if isinstance(coordinates, CartesianCoordinates): - return ( - idx, - weights, - cov_weights, - CartesianCoordinates.from_kwargs( - x=samples[:, 0], - y=samples[:, 1], - z=samples[:, 2], - vx=samples[:, 3], - vy=samples[:, 4], - vz=samples[:, 5], - time=times, - covariance=None, - origin=origins, - frame=coordinates.frame, - ), - ) - elif isinstance(coordinates, SphericalCoordinates): - return ( - idx, - weights, - cov_weights, - SphericalCoordinates.from_kwargs( - rho=samples[:, 0], - lon=samples[:, 1], - lat=samples[:, 2], - vrho=samples[:, 3], - vlon=samples[:, 4], - vlat=samples[:, 5], - time=times, - covariance=None, - origin=origins, - frame=coordinates.frame, - ), - ) - elif isinstance(coordinates, KeplerianCoordinates): - return ( - idx, - weights, - cov_weights, - KeplerianCoordinates.from_kwargs( - a=samples[:, 0], - e=samples[:, 1], - i=samples[:, 2], - raan=samples[:, 3], - ap=samples[:, 4], - M=samples[:, 5], - time=times, - covariance=None, - origin=origins, - frame=coordinates.frame, - ), - ) - elif isinstance(coordinates, CometaryCoordinates): - return ( - idx, - weights, - cov_weights, - CometaryCoordinates.from_kwargs( - q=samples[:, 0], - e=samples[:, 1], - i=samples[:, 2], - raan=samples[:, 3], - ap=samples[:, 4], - tp=samples[:, 5], - time=times, - covariance=None, - origin=origins, - frame=coordinates.frame, - ), + variants_list.append( + CoordinateVariants.from_kwargs( + index=np.full(len(samples), i), + sample=coordinates.from_kwargs( + origin=qv.concatenate( + [coordinate_i.origin for i in range(len(samples))] + ), + time=qv.concatenate( + [coordinate_i.time for i in range(len(samples))] + ), + frame=coordinate_i.frame, + **{dim: samples[:, i] for i, dim in enumerate(dimensions)}, + ), + weight=W, + weight_cov=W_cov, + ) ) - else: - raise ValueError(f"Unsupported coordinate type: {type(coordinates)}") + return qv.concatenate(variants_list) diff --git a/adam_core/orbits/variants.py b/adam_core/orbits/variants.py index 67300fec..20627d0e 100644 --- a/adam_core/orbits/variants.py +++ b/adam_core/orbits/variants.py @@ -65,7 +65,7 @@ def create( variants_orbits : '~adam_core.orbits.variants.VariantOrbits' The variant orbits. """ - idx, W, W_cov, variant_coordinates = create_coordinate_variants( + variant_coordinates = create_coordinate_variants( orbits.coordinates, method=method, num_samples=num_samples, @@ -74,11 +74,11 @@ def create( kappa=kappa, ) return cls.from_kwargs( - orbit_id=pc.take(orbits.orbit_id, idx), - object_id=pc.take(orbits.object_id, idx), - weights=W, - weights_cov=W_cov, - coordinates=variant_coordinates, + orbit_id=pc.take(orbits.orbit_id, variant_coordinates.index), + object_id=pc.take(orbits.object_id, variant_coordinates.index), + weights=variant_coordinates.weight, + weights_cov=variant_coordinates.weight_cov, + coordinates=variant_coordinates.sample, ) def link_to_orbits(