Skip to content

Commit

Permalink
Merge pull request #35 from B612-Asteroid-Institute/jm/pyoorb-time-ch…
Browse files Browse the repository at this point in the history
…ecks

Check if times returned from PYOORB are close to the input times
  • Loading branch information
moeyensj authored Aug 10, 2023
2 parents 820368e + 627540b commit 2ee365b
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 11 deletions.
22 changes: 12 additions & 10 deletions adam_core/propagator/pyoorb.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from ..coordinates.times import Times
from ..orbits.orbits import Orbits
from .propagator import Propagator
from .utils import _assert_times_almost_equal


class OpenOrbTimescale(enum.Enum):
Expand Down Expand Up @@ -208,7 +209,7 @@ def _propagate_orbits(self, orbits: Orbits, times: Time) -> Orbits:
# Convert orbits into PYOORB format
orbits_pyoorb = self._configure_orbits(
orbits.coordinates.values,
orbits.coordinates.times.to_astropy().tt.mjd,
orbits.coordinates.time.to_astropy().tt.mjd,
OpenOrbOrbitType.CARTESIAN,
OpenOrbTimescale.TT,
magnitude=None,
Expand All @@ -230,6 +231,9 @@ def _propagate_orbits(self, orbits: Orbits, times: Time) -> Orbits:
)
states_list.append(orbits_pyoorb_i)

if err != 0:
raise RuntimeError(f"PYOORB propagation failed with error code {err}.")

# Convert list of new states into a pandas data frame
# These states at the moment will always be return as cartesian
# state vectors
Expand All @@ -251,19 +255,17 @@ def _propagate_orbits(self, orbits: Orbits, times: Time) -> Orbits:
vz = states[:, 6]
mjd_tt = states[:, 8]

# Check to make sure the desired times are within an acceptable
# tolerance
_assert_times_almost_equal(mjd_tt, np.repeat(epochs_pyoorb[:, 0], len(orbits)))

# Convert output epochs to TDB
times_ = Time(mjd_tt, format="mjd", scale="tt")
times_ = times_.tdb

if orbits.object_ids is not None:
object_ids = orbits.object_ids.to_numpy(zero_copy_only=False)[orbit_ids_]
else:
object_ids = None

if orbits.orbit_ids is not None:
orbit_ids = orbits.orbit_ids.to_numpy(zero_copy_only=False)[orbit_ids_]
else:
orbit_ids = None
# Map the object and orbit IDs back to the input arrays
object_ids = orbits.object_id.to_numpy(zero_copy_only=False)[orbit_ids_]
orbit_ids = orbits.orbit_id.to_numpy(zero_copy_only=False)[orbit_ids_]

propagated_orbits = Orbits.from_kwargs(
orbit_id=orbit_ids,
Expand Down
18 changes: 17 additions & 1 deletion adam_core/propagator/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import numpy as np
import pytest
from quivr import Float64Column, Table

from ..utils import _iterate_chunks
from ..utils import _assert_times_almost_equal, _iterate_chunks


class SampleTable(Table):
Expand Down Expand Up @@ -34,3 +35,18 @@ def test__iterate_chunks_table():
assert len(chunk) == 1
assert isinstance(chunk, SampleTable)
np.testing.assert_equal(chunk.a.to_numpy(), np.arange(i * 2, i * 2 + 1))


def test__assert_times_almost_equal():
have = np.array([1.0, 2.0, 3.0])
want = np.array([1.0, 2.0, 3.0])

_assert_times_almost_equal(have, want, tolerance=1.0)

with pytest.raises(ValueError):
have = np.array([1.0, 2.0, 3.0])
want = np.array([1.0, 2.0, 3.0])

# Offset have by 2 ms
have += 2 / 86800 / 1000
_assert_times_almost_equal(have, want, tolerance=1.0)
28 changes: 28 additions & 0 deletions adam_core/propagator/utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from typing import Iterable, Sequence

import numpy as np
import pyarrow as pa
from pyarrow import compute as pc

from ..orbits import Orbits

MILLISECOND_IN_DAYS = 1 / 86400 / 1000


def _iterate_chunks(iterable: Sequence, chunk_size: int) -> Iterable:
"""
Expand Down Expand Up @@ -70,3 +73,28 @@ def sort_propagated_orbits(propagated_orbits: Orbits) -> Orbits:
),
)
return propagated_orbits.take(indices)


def _assert_times_almost_equal(
have: np.ndarray, want: np.ndarray, tolerance: float = 0.1
):
"""
Raises a ValueError if the time arrays (in units of days such as MJD) are not within the
tolerance in milliseconds of each other.
Parameters
----------
have : `~numpy.ndarray`
Times (in units of days) to check.
want : `~numpy.ndarray`
Times (in units of days) to check.
Raises
------
ValueError: If the time arrays are not within the tolerance in milliseconds of each other.
"""
tolerance_in_days = tolerance * MILLISECOND_IN_DAYS

diff = np.abs(have - want)
if np.any(diff > tolerance_in_days):
raise ValueError(f"Times were not within {tolerance:.2f} ms of each other.")

0 comments on commit 2ee365b

Please sign in to comment.