Skip to content

Commit

Permalink
Use cached_property and types
Browse files Browse the repository at this point in the history
  • Loading branch information
aarchiba committed Feb 10, 2024
1 parent fa790eb commit 4521272
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 84 deletions.
122 changes: 75 additions & 47 deletions src/pint/observatory/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,17 @@
necessary.
"""

from copy import deepcopy
import os
import textwrap
from collections import defaultdict
from collections.abc import Callable
from copy import deepcopy
from io import StringIO
from pathlib import Path
from typing import Optional, Union

import astropy.coordinates
import astropy.time
import astropy.units as u
import numpy as np
from astropy.coordinates import EarthLocation
Expand Down Expand Up @@ -97,7 +100,7 @@ def _load_gps_clock():
)


def _load_bipm_clock(bipm_version):
def _load_bipm_clock(bipm_version: str):
bipm_version = bipm_version.lower()
if bipm_version not in _bipm_clock_versions:
try:
Expand Down Expand Up @@ -136,34 +139,40 @@ class Observatory:
position.
"""

fullname: str
aliases: list[str]
include_gps: bool
include_bipm: bool
bipm_version: str

# This is a dict containing all defined Observatory instances,
# keyed on standard observatory name.
_registry = {}
_registry: dict[str, "Observatory"] = {}

# This is a dict mapping any defined aliases to the corresponding
# standard name.
_alias_map = {}
_alias_map: dict[str, str] = {}

def __init__(
self,
name,
fullname=None,
aliases=None,
include_gps=True,
include_bipm=True,
bipm_version=bipm_default,
overwrite=False,
name: str,
fullname: Optional[str] = None,
aliases: Optional[list[str]] = None,
include_gps: bool = True,
include_bipm: bool = True,
bipm_version: str = bipm_default,
overwrite: bool = False,
):
self._name = name.lower()
self._aliases = (
self._name: str = name.lower()
self._aliases: list[str] = (
list(set(map(str.lower, aliases))) if aliases is not None else []
)
if aliases is not None:
Observatory._add_aliases(self, aliases)
self.fullname = fullname if fullname is not None else name
self.include_gps = include_gps
self.include_bipm = include_bipm
self.bipm_version = bipm_version
self.fullname: str = fullname if fullname is not None else name
self.include_gps: bool = include_gps
self.include_bipm: bool = include_bipm
self.bipm_version: str = bipm_version

if name.lower() in Observatory._registry:
if not overwrite:
Expand All @@ -175,16 +184,18 @@ def __init__(
Observatory._register(self, name)

@classmethod
def _register(cls, obs, name):
"""Add an observatory to the registry using the specified name
(which will be converted to lower case). If an existing observatory
def _register(cls, obs: "Observatory", name: str):
"""Add an observatory to the registry using the specified name (which will be converted to lower case).
If an existing observatory
of the same name exists, it will be replaced with the new one.
The Observatory instance's name attribute will be updated for
consistency."""
consistency.
"""
cls._registry[name.lower()] = obs

@classmethod
def _add_aliases(cls, obs, aliases):
def _add_aliases(cls, obs: "Observatory", aliases: list[str]):
"""Add aliases for the specified Observatory. Aliases
should be given as a list. If any of the new aliases are already in
use, they will be replaced. Aliases are not checked against the
Expand All @@ -196,14 +207,17 @@ def _add_aliases(cls, obs, aliases):
cls._alias_map[a.lower()] = obs.name

@staticmethod
def gps_correction(t, limits="warn"):
def gps_correction(t: astropy.time.Time, limits: str = "warn"):
"""Compute the GPS clock corrections for times t."""
log.info("Applying GPS to UTC clock correction (~few nanoseconds)")
_load_gps_clock()
assert _gps_clock is not None
return _gps_clock.evaluate(t, limits=limits)

@staticmethod
def bipm_correction(t, bipm_version=bipm_default, limits="warn"):
def bipm_correction(
t: astropy.time.Time, bipm_version: str = bipm_default, limits: str = "warn"
):
"""Compute the GPS clock corrections for times t."""
log.info(f"Applying TT(TAI) to TT({bipm_version}) clock correction (~27 us)")
tt2tai = 32.184 * 1e6 * u.us
Expand All @@ -214,7 +228,7 @@ def bipm_correction(t, bipm_version=bipm_default, limits="warn"):
)

@classmethod
def clear_registry(cls):
def clear_registry(cls) -> None:
"""Clear registry for ground-based observatories."""
cls._registry = {}
cls._alias_map = {}
Expand All @@ -229,7 +243,7 @@ def names(cls):
return cls._registry.keys()

@classmethod
def names_and_aliases(cls):
def names_and_aliases(cls) -> dict[str, list[str]]:
"""List all observatories and their aliases"""
import pint.observatory.topo_obs # noqa
import pint.observatory.special_locations # noqa
Expand All @@ -241,15 +255,15 @@ def names_and_aliases(cls):
# setter methods that update the registries appropriately.

@property
def name(self):
def name(self) -> str:
return self._name

@property
def aliases(self):
def aliases(self) -> list[str]:
return self._aliases

@classmethod
def get(cls, name):
def get(cls, name: str) -> "Observatory":
"""Returns the Observatory instance for the specified name/alias.
If the name has not been defined, an error will be raised. Aside
Expand Down Expand Up @@ -303,9 +317,12 @@ def get(cls, name):
# Any which raise NotImplementedError below must be implemented in
# derived classes.

def earth_location_itrf(self, time=None):
"""Returns observatory geocentric position as an astropy
EarthLocation object. For observatories where this is not
def earth_location_itrf(
self, time: Optional[astropy.time.Time] = None
) -> Union[None, np.ndarray]:
"""Returns observatory geocentric position as an astropy EarthLocation object.
For observatories where this is not
relevant, None can be returned.
The location is in the International Terrestrial Reference Frame (ITRF).
Expand All @@ -319,8 +336,9 @@ def earth_location_itrf(self, time=None):
"""
return None

def get_gcrs(self, t, ephem=None):
"""Return position vector of observatory in GCRS
def get_gcrs(self, t: astropy.time.Time, ephem=None):
"""Return position vector of observatory in GCRS.
t is an astropy.Time or array of astropy.Time objects
ephem is a link to an ephemeris file. Needed for SSB observatory
Returns a 3-vector of Quantities representing the position
Expand All @@ -329,14 +347,17 @@ def get_gcrs(self, t, ephem=None):
raise NotImplementedError

@property
def timescale(self):
"""Returns the timescale that TOAs from this observatory will be in,
once any clock corrections have been applied. This should be a
def timescale(self) -> str:
"""Returns the timescale that TOAs from this observatory will be in, once any clock corrections have been applied.
This should be a
string suitable to be passed directly to the scale argument of
astropy.time.Time()."""
raise NotImplementedError

def clock_corrections(self, t, limits="warn"):
def clock_corrections(
self, t: astropy.time.Time, limits: str = "warn"
) -> u.Quantity:
"""Compute clock corrections for a Time array.
Given an array-valued Time, return the clock corrections
Expand All @@ -356,7 +377,7 @@ def clock_corrections(self, t, limits="warn"):

return corr

def last_clock_correction_mjd(self):
def last_clock_correction_mjd(self) -> float:
"""Return the MJD of the last available clock correction.
Returns ``np.inf`` if no clock corrections are relevant.
Expand All @@ -365,6 +386,7 @@ def last_clock_correction_mjd(self):

if self.include_gps:
_load_gps_clock()
assert _gps_clock is not None
t = min(t, _gps_clock.last_correction_mjd())
if self.include_bipm:
_load_bipm_clock(self.bipm_version)
Expand All @@ -374,7 +396,13 @@ def last_clock_correction_mjd(self):
)
return t

def get_TDBs(self, t, method="default", ephem=None, options=None):
def get_TDBs(
self,
t: astropy.time.Time,
method: Union[str, Callable] = "default",
ephem: Optional[str] = None,
options: Optional[dict] = None,
):
"""This is a high level function for converting TOAs to TDB time scale.
Different method can be applied to obtain the result. Current supported
Expand Down Expand Up @@ -409,13 +437,13 @@ def get_TDBs(self, t, method="default", ephem=None, options=None):
t = Time([t])
if t.scale == "tdb":
return t
# Check the method. This pattern is from numpy minimize
meth = "_custom" if callable(method) else method.lower()
if options is None:
options = {}
if meth == "_custom":
if callable(method):
options = dict(options)
return method(t, **options)
else:
meth = method.lower()
if meth == "default":
return self._get_TDB_default(t, ephem)
elif meth == "ephemeris":
Expand All @@ -428,17 +456,17 @@ def get_TDBs(self, t, method="default", ephem=None, options=None):
else:
raise ValueError(f"Unknown method '{method}'.")

def _get_TDB_default(self, t, ephem):
def _get_TDB_default(self, t: astropy.time.Time, ephem):
return t.tdb

def _get_TDB_ephem(self, t, ephem):
def _get_TDB_ephem(self, t: astropy.time.Time, ephem):
"""Read the ephem TDB-TT column.
This column is provided by DE4XXt version of ephemeris.
"""
raise NotImplementedError

def posvel(self, t, ephem, group=None):
def posvel(self, t: astropy.time.Time, ephem, group=None):
"""Return observatory position and velocity for the given times.
Position is relative to solar system barycenter; times are
Expand All @@ -451,7 +479,7 @@ def posvel(self, t, ephem, group=None):


def get_observatory(
name, include_gps=None, include_bipm=None, bipm_version=bipm_default
name: str, include_gps=None, include_bipm=None, bipm_version: str = bipm_default
):
"""Convenience function to get observatory object with options.
Expand Down
Loading

0 comments on commit 4521272

Please sign in to comment.