Skip to content

Commit

Permalink
Add _assert_times_almost_equal
Browse files Browse the repository at this point in the history
  • Loading branch information
moeyensj committed Aug 10, 2023
1 parent bef633e commit f27cbc0
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 1 deletion.
5 changes: 5 additions & 0 deletions adam_core/propagator/pyoorb.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from ..coordinates.times import Times
from ..orbits.orbits import Orbits
from .propagator import Propagator
from .utils import _assert_times_almost_equal


class OpenOrbTimescale(enum.Enum):
Expand Down Expand Up @@ -251,6 +252,10 @@ def _propagate_orbits(self, orbits: Orbits, times: Time) -> Orbits:
vz = states[:, 6]
mjd_tt = states[:, 8]

# Check to make sure the desired times are within an acceptable
# tolerance
_assert_times_almost_equal(mjd_tt, np.repeat(epochs_pyoorb[:, 0], len(orbits)))

# Convert output epochs to TDB
times_ = Time(mjd_tt, format="mjd", scale="tt")
times_ = times_.tdb
Expand Down
18 changes: 17 additions & 1 deletion adam_core/propagator/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import numpy as np
import pytest
from quivr import Float64Column, Table

from ..utils import _iterate_chunks
from ..utils import _assert_times_almost_equal, _iterate_chunks


class SampleTable(Table):
Expand Down Expand Up @@ -34,3 +35,18 @@ def test__iterate_chunks_table():
assert len(chunk) == 1
assert isinstance(chunk, SampleTable)
np.testing.assert_equal(chunk.a.to_numpy(), np.arange(i * 2, i * 2 + 1))


def test__assert_times_almost_equal():
have = np.array([1.0, 2.0, 3.0])
want = np.array([1.0, 2.0, 3.0])

_assert_times_almost_equal(have, want, tolerance=1.0)

with pytest.raises(ValueError):
have = np.array([1.0, 2.0, 3.0])
want = np.array([1.0, 2.0, 3.0])

# Offset have by 2 ms
have += 2 / 86800 / 1000
_assert_times_almost_equal(have, want, tolerance=1.0)
28 changes: 28 additions & 0 deletions adam_core/propagator/utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from typing import Iterable, Sequence

import numpy as np
import pyarrow as pa
from pyarrow import compute as pc

from ..orbits import Orbits

MILLISECOND_IN_DAYS = 1 / 86400 / 1000


def _iterate_chunks(iterable: Sequence, chunk_size: int) -> Iterable:
"""
Expand Down Expand Up @@ -70,3 +73,28 @@ def sort_propagated_orbits(propagated_orbits: Orbits) -> Orbits:
),
)
return propagated_orbits.take(indices)


def _assert_times_almost_equal(
have: np.ndarray, want: np.ndarray, tolerance: float = 0.1
):
"""
Raises a ValueError if the time arrays (in units of days such as MJD) are not within the
tolerance in milliseconds of each other.
Parameters
----------
have : `~numpy.ndarray`
Times (in units of days) to check.
want : `~numpy.ndarray`
Times (in units of days) to check.
Raises
------
ValueError: If the time arrays are not within the tolerance in milliseconds of each other.
"""
tolerance_in_days = tolerance * MILLISECOND_IN_DAYS

diff = np.abs(have - want)
if np.any(diff > tolerance_in_days):
raise ValueError(f"Times were not within {tolerance:.2f} ms of each other.")

0 comments on commit f27cbc0

Please sign in to comment.