diff --git a/src/adam_core/observers/observers.py b/src/adam_core/observers/observers.py index d9cc09ae..c07019e9 100644 --- a/src/adam_core/observers/observers.py +++ b/src/adam_core/observers/observers.py @@ -1,8 +1,11 @@ import warnings from typing import Union +import numpy as np +import numpy.typing as npt import pandas as pd import pyarrow as pa +import pyarrow.compute as pc import quivr as qv from mpc_obscodes import mpc_obscodes from typing_extensions import Self @@ -54,7 +57,9 @@ class Observers(qv.Table): coordinates = CartesianCoordinates.as_column() @classmethod - def from_codes(cls, codes: pa.Array, times: Timestamp) -> Self: + def from_codes( + cls, codes: Union[list, npt.NDArray[np.str_], pa.Array], times: Timestamp + ) -> Self: """ Create an Observers table from a list of codes and times. The codes and times do not need to be unique. The observer state will be calculated for each time @@ -62,7 +67,7 @@ def from_codes(cls, codes: pa.Array, times: Timestamp) -> Self: Parameters ---------- - codes : pa.Array (N) + codes : Union[list, npt.NDArray[np.str], pa.Array] (N) MPC observatory codes for which to find the states. times : Timestamp (N) Epochs for which to find the observatory locations. @@ -72,8 +77,96 @@ def from_codes(cls, codes: pa.Array, times: Timestamp) -> Self: observers : `~adam_core.observers.observers.Observers` (N) The observer and its state at each time. """ - assert len(codes) == len(times) - raise NotImplementedError + if len(codes) != len(times): + raise ValueError("codes and times must have the same length.") + + if not isinstance(codes, pa.Array): + codes = pa.array(codes, type=pa.large_string()) + + # Create a table with the codes and times and add + # and index column to track the original order + table = pa.Table.from_pydict( + { + "index": pa.array(range(len(codes)), type=pa.uint64()), + "code": codes, + "times.days": times.days, + "times.nanos": times.nanos, + } + ) + + # Expected observers schema with the addition of a + # column that tracks the original index + observers_schema = pa.schema( + [ + pa.field("code", pa.large_string(), nullable=False), + pa.field( + "coordinates", + pa.struct( + [ + pa.field("x", pa.float64()), + pa.field("y", pa.float64()), + pa.field("z", pa.float64()), + pa.field("vx", pa.float64()), + pa.field("vy", pa.float64()), + pa.field("vz", pa.float64()), + pa.field( + "time", + pa.struct( + [ + pa.field("days", pa.int64()), + pa.field("nanos", pa.int64()), + ] + ), + ), + pa.field( + "covariance", + pa.struct( + [pa.field("values", pa.large_list(pa.float64()))] + ), + ), + pa.field( + "origin", + pa.struct([pa.field("code", pa.large_string())]), + ), + ] + ), + ), + pa.field("index", pa.uint64()), + ], + metadata={ + "coordinates.time.scale": times.scale, + "coordinates.frame": "ecliptic", + }, + ) + + # Create an empty table with the expected schema + observers_table = observers_schema.empty_table() + + # Loop through each unique code and calculate the observer's + # state for each time (these can be non-unique as cls.from_code + # will handle this) + for code in table["code"].unique(): + + times_code = table.filter(pc.equal(table["code"], code)) + + observers = cls.from_code( + code.as_py(), + Timestamp.from_kwargs( + days=times_code["times.days"], + nanos=times_code["times.nanos"], + scale=times.scale, + ), + ) + + observers_table_i = observers.table.append_column( + "index", times_code["index"] + ) + observers_table = pa.concat_tables( + [observers_table, observers_table_i] + ).combine_chunks() + + observers_table = observers_table.sort_by(("index")).drop_columns(["index"]) + return cls.from_pyarrow(observers_table) @classmethod def from_code(cls, code: Union[str, OriginCodes], times: Timestamp) -> Self: diff --git a/src/adam_core/observers/tests/test_observers.py b/src/adam_core/observers/tests/test_observers.py new file mode 100644 index 00000000..be75e007 --- /dev/null +++ b/src/adam_core/observers/tests/test_observers.py @@ -0,0 +1,61 @@ +import pyarrow as pa +import pyarrow.compute as pc +import pytest + +from ...time import Timestamp +from ..observers import Observers + + +@pytest.fixture +def codes_times() -> tuple[pa.Array, Timestamp]: + codes = pa.array( + ["500", "X05", "I41", "X05", "I41", "W84", "500"], + ) + + times = Timestamp.from_kwargs( + days=[59000, 59001, 59002, 59003, 59004, 59005, 59006], + nanos=[0, 0, 0, 0, 0, 0, 0], + scale="tdb", + ) + return codes, times + + +def test_Observers_from_codes(codes_times) -> None: + # Test that observers from code returns the correct number of observers + # and in the order that they were requested + codes, times = codes_times + + observers = Observers.from_codes(codes, times) + assert len(observers) == 7 + assert pc.all(pc.equal(observers.code, codes)).as_py() + assert pc.all(pc.equal(observers.coordinates.time.days, times.days)).as_py() + assert pc.all(pc.equal(observers.coordinates.time.nanos, times.nanos)).as_py() + + +def test_Observers_from_codes_non_pyarrow(codes_times) -> None: + # Test that observers from code returns the correct number of observers + # and in the order that they were requested + codes, times = codes_times + + observers = Observers.from_codes(codes.to_numpy(zero_copy_only=False), times) + assert len(observers) == 7 + assert pc.all(pc.equal(observers.code, codes)).as_py() + assert pc.all(pc.equal(observers.coordinates.time.days, times.days)).as_py() + assert pc.all(pc.equal(observers.coordinates.time.nanos, times.nanos)).as_py() + + observers = Observers.from_codes(codes.to_pylist(), times) + assert len(observers) == 7 + assert pc.all(pc.equal(observers.code, codes)).as_py() + assert pc.all(pc.equal(observers.coordinates.time.days, times.days)).as_py() + assert pc.all(pc.equal(observers.coordinates.time.nanos, times.nanos)).as_py() + + +def test_Observers_from_codes_raises(codes_times) -> None: + # Test that observers from code raises an error if the codes and times + # are not the same length + codes, times = codes_times + + with pytest.raises(ValueError, match="codes and times must have the same length."): + Observers.from_codes(codes[:3], times) + with pytest.raises(ValueError, match="codes and times must have the same length."): + Observers.from_codes(codes, times[:3])