Skip to content

Commit

Permalink
Define CoordinateVariants table inside create_coordinate_variants
Browse files Browse the repository at this point in the history
Motivated by @spenczar: "Or if you're insane, you could define the class inside the
function. I think this really would work!" (#41 (comment))

Co-authored-by: Spencer Nelson <[email protected]>
  • Loading branch information
moeyensj and spenczar committed Aug 22, 2023
1 parent 4e5165c commit 45709f2
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 113 deletions.
146 changes: 39 additions & 107 deletions adam_core/coordinates/variants.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Literal, Tuple, Union
from typing import Literal, Union

import numpy as np
import quivr as qv
Expand Down Expand Up @@ -29,7 +29,7 @@ def create_coordinate_variants(
alpha: float = 1,
beta: float = 0,
kappa: float = 0,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, CoordinateType]:
) -> "CoordinateVariants":
"""
Sample and create variants for the given coordinates by sampling the covariance matrices.
There are three supported methods:
Expand Down Expand Up @@ -69,36 +69,38 @@ def create_coordinate_variants(
Returns
-------
idx : '~numpy.ndarray'
The index of the coordinate that each sample belongs to.
weights : '~numpy.ndarray'
Weights of each sample.
cov_weights : '~numpy.ndarray'
Weights of the samples to reconstruct covariance matrix.
samples : {'~adam_core.coordinates.cartesian.CartesianCoordinates',
'~adam_core.coordinates.keplerian.KeplerianCoordinates',
'~adam_core.coordinates.cometary.CometaryCoordinates',
'~adam_core.coordinates.spherical.SphericalCoordinates'}
The samples drawn from the coordinate covariance matrices.
variants : '~adam_core.coordinates.variants.VariantCoordinates'
The variant coordinates.
Raises
------
ValueError:
If the covariance matrices are all undefined.
If the input coordinates are not supported.
"""
idx_list = []
samples_list = []
weights_list = []
cov_weights_list = []
origins_list = []
times_list = []

if coordinates.covariance.is_all_nan():
raise ValueError(
"Cannot sample coordinate covariances when covariances are all undefined."
)

class CoordinateVariants(qv.Table):
index = qv.Int64Column()
sample = coordinates.as_column()
weight = qv.Float64Column()
weight_cov = qv.Float64Column()

if isinstance(coordinates, CartesianCoordinates):
dimensions = ["x", "y", "z", "vx", "vy", "vz"]
elif isinstance(coordinates, SphericalCoordinates):
dimensions = ["rho", "lon", "lat", "vrho", "vlon", "vlat"]
elif isinstance(coordinates, KeplerianCoordinates):
dimensions = ["a", "e", "i", "raan", "ap", "M"]
elif isinstance(coordinates, CometaryCoordinates):
dimensions = ["q", "e", "i", "raan", "ap", "tp"]
else:
raise ValueError(f"Unsupported coordinate type: {type(coordinates)}")

variants_list = []
for i, coordinate_i in enumerate(coordinates):

mean = coordinate_i.values[0]
Expand Down Expand Up @@ -146,92 +148,22 @@ def create_coordinate_variants(
else:
raise ValueError(f"Unknown coordinate covariance sampling method: {method}")

origins_list += [coordinate_i.origin for i in range(len(samples))]
times_list += [coordinate_i.time for i in range(len(samples))]
samples_list.append(samples)
weights_list.append(W)
cov_weights_list.append(W_cov)
idx_list.append(np.full(len(samples), i))

samples = np.concatenate(samples_list)
idx = np.concatenate(idx_list)
weights = np.concatenate(weights_list)
cov_weights = np.concatenate(cov_weights_list)
origins = qv.concatenate(origins_list)
times = qv.concatenate(times_list)

if isinstance(coordinates, CartesianCoordinates):
return (
idx,
weights,
cov_weights,
CartesianCoordinates.from_kwargs(
x=samples[:, 0],
y=samples[:, 1],
z=samples[:, 2],
vx=samples[:, 3],
vy=samples[:, 4],
vz=samples[:, 5],
time=times,
covariance=None,
origin=origins,
frame=coordinates.frame,
),
)
elif isinstance(coordinates, SphericalCoordinates):
return (
idx,
weights,
cov_weights,
SphericalCoordinates.from_kwargs(
rho=samples[:, 0],
lon=samples[:, 1],
lat=samples[:, 2],
vrho=samples[:, 3],
vlon=samples[:, 4],
vlat=samples[:, 5],
time=times,
covariance=None,
origin=origins,
frame=coordinates.frame,
),
)
elif isinstance(coordinates, KeplerianCoordinates):
return (
idx,
weights,
cov_weights,
KeplerianCoordinates.from_kwargs(
a=samples[:, 0],
e=samples[:, 1],
i=samples[:, 2],
raan=samples[:, 3],
ap=samples[:, 4],
M=samples[:, 5],
time=times,
covariance=None,
origin=origins,
frame=coordinates.frame,
),
)
elif isinstance(coordinates, CometaryCoordinates):
return (
idx,
weights,
cov_weights,
CometaryCoordinates.from_kwargs(
q=samples[:, 0],
e=samples[:, 1],
i=samples[:, 2],
raan=samples[:, 3],
ap=samples[:, 4],
tp=samples[:, 5],
time=times,
covariance=None,
origin=origins,
frame=coordinates.frame,
),
variants_list.append(
CoordinateVariants.from_kwargs(
index=np.full(len(samples), i),
sample=coordinates.from_kwargs(
origin=qv.concatenate(
[coordinate_i.origin for i in range(len(samples))]
),
time=qv.concatenate(
[coordinate_i.time for i in range(len(samples))]
),
frame=coordinate_i.frame,
**{dim: samples[:, i] for i, dim in enumerate(dimensions)},
),
weight=W,
weight_cov=W_cov,
)
)

else:
raise ValueError(f"Unsupported coordinate type: {type(coordinates)}")
return qv.concatenate(variants_list)
12 changes: 6 additions & 6 deletions adam_core/orbits/variants.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def create(
variants_orbits : '~adam_core.orbits.variants.VariantOrbits'
The variant orbits.
"""
idx, W, W_cov, variant_coordinates = create_coordinate_variants(
variant_coordinates = create_coordinate_variants(
orbits.coordinates,
method=method,
num_samples=num_samples,
Expand All @@ -74,11 +74,11 @@ def create(
kappa=kappa,
)
return cls.from_kwargs(
orbit_id=pc.take(orbits.orbit_id, idx),
object_id=pc.take(orbits.object_id, idx),
weights=W,
weights_cov=W_cov,
coordinates=variant_coordinates,
orbit_id=pc.take(orbits.orbit_id, variant_coordinates.index),
object_id=pc.take(orbits.object_id, variant_coordinates.index),
weights=variant_coordinates.weight,
weights_cov=variant_coordinates.weight_cov,
coordinates=variant_coordinates.sample,
)

def link_to_orbits(
Expand Down

0 comments on commit 45709f2

Please sign in to comment.