diff --git a/docs/changes/2600.feature.rst b/docs/changes/2600.feature.rst new file mode 100644 index 00000000000..1c996397c6d --- /dev/null +++ b/docs/changes/2600.feature.rst @@ -0,0 +1 @@ +Add Interpolator class to generalize the PointingInterpolator in the io collection. diff --git a/src/ctapipe/io/__init__.py b/src/ctapipe/io/__init__.py index afe96f39430..28603643160 100644 --- a/src/ctapipe/io/__init__.py +++ b/src/ctapipe/io/__init__.py @@ -18,6 +18,10 @@ from .datawriter import DATA_MODEL_VERSION, DataWriter +from .interpolation import ( + Interpolator, + PointingInterpolator, +) __all__ = [ "HDF5TableWriter", @@ -36,4 +40,6 @@ "DataWriter", "DATA_MODEL_VERSION", "get_hdf5_datalevels", + "Interpolator", + "PointingInterpolator", ] diff --git a/src/ctapipe/io/hdf5eventsource.py b/src/ctapipe/io/hdf5eventsource.py index afbe2153d2c..eb91b216c27 100644 --- a/src/ctapipe/io/hdf5eventsource.py +++ b/src/ctapipe/io/hdf5eventsource.py @@ -52,7 +52,7 @@ from .datalevels import DataLevel from .eventsource import EventSource from .hdf5tableio import HDF5TableReader -from .pointing import PointingInterpolator +from .interpolation import PointingInterpolator from .tableloader import DL2_SUBARRAY_GROUP, DL2_TELESCOPE_GROUP, POINTING_GROUP __all__ = ["HDF5EventSource"] diff --git a/src/ctapipe/io/interpolation.py b/src/ctapipe/io/interpolation.py new file mode 100644 index 00000000000..82de9f0bd19 --- /dev/null +++ b/src/ctapipe/io/interpolation.py @@ -0,0 +1,182 @@ +from abc import ABCMeta, abstractmethod +from typing import Any + +import astropy.units as u +import numpy as np +import tables +from astropy.time import Time +from scipy.interpolate import interp1d + +from ctapipe.core import Component, traits + +from .astropy_helpers import read_table + + +class Interpolator(Component, metaclass=ABCMeta): + """ + Interpolator parent class. + + Parameters + ---------- + h5file : None | tables.File + A open hdf5 file with read access. + """ + + bounds_error = traits.Bool( + default_value=True, + help="If true, raises an exception when trying to extrapolate out of the given table", + ).tag(config=True) + + extrapolate = traits.Bool( + help="If bounds_error is False, this flag will specify whether values outside" + "the available values are filled with nan (False) or extrapolated (True).", + default_value=False, + ).tag(config=True) + + telescope_data_group = None + required_columns = set() + expected_units = {} + + def __init__(self, h5file=None, **kwargs): + super().__init__(**kwargs) + + if h5file is not None and not isinstance(h5file, tables.File): + raise TypeError("h5file must be a tables.File") + self.h5file = h5file + + self.interp_options: dict[str, Any] = dict(assume_sorted=True, copy=False) + if self.bounds_error: + self.interp_options["bounds_error"] = True + elif self.extrapolate: + self.interp_options["bounds_error"] = False + self.interp_options["fill_value"] = "extrapolate" + else: + self.interp_options["bounds_error"] = False + self.interp_options["fill_value"] = np.nan + + self._interpolators = {} + + @abstractmethod + def add_table(self, tel_id, input_table): + """ + Add a table to this interpolator + This method reads input tables and creates instances of the needed interpolators + to be added to _interpolators. The first index of _interpolators needs to be + tel_id, the second needs to be the name of the parameter that is to be interpolated + + Parameters + ---------- + tel_id : int + Telescope id + input_table : astropy.table.Table + Table of pointing values, expected columns + are always ``time`` as ``Time`` column and + other columns for the data that is to be interpolated + """ + + pass + + def _check_tables(self, input_table): + missing = self.required_columns - set(input_table.colnames) + if len(missing) > 0: + raise ValueError(f"Table is missing required column(s): {missing}") + for col in self.expected_units: + unit = input_table[col].unit + if unit is None: + if self.expected_units[col] is not None: + raise ValueError( + f"{col} must have units compatible with '{self.expected_units[col].name}'" + ) + elif not self.expected_units[col].is_equivalent(unit): + if self.expected_units[col] is None: + raise ValueError(f"{col} must have units compatible with 'None'") + else: + raise ValueError( + f"{col} must have units compatible with '{self.expected_units[col].name}'" + ) + + def _check_interpolators(self, tel_id): + if tel_id not in self._interpolators: + if self.h5file is not None: + self._read_parameter_table(tel_id) # might need to be removed + else: + raise KeyError(f"No table available for tel_id {tel_id}") + + def _read_parameter_table(self, tel_id): + input_table = read_table( + self.h5file, + f"{self.telescope_data_group}/tel_{tel_id:03d}", + ) + self.add_table(tel_id, input_table) + + +class PointingInterpolator(Interpolator): + """ + Interpolator for pointing and pointing correction data + """ + + telescope_data_group = "/dl0/monitoring/telescope/pointing" + required_columns = frozenset(["time", "azimuth", "altitude"]) + expected_units = {"azimuth": u.rad, "altitude": u.rad} + + def __call__(self, tel_id, time): + """ + Interpolate alt/az for given time and tel_id. + + Parameters + ---------- + tel_id : int + telescope id + time : astropy.time.Time + time for which to interpolate the pointing + + Returns + ------- + altitude : astropy.units.Quantity[deg] + interpolated altitude angle + azimuth : astropy.units.Quantity[deg] + interpolated azimuth angle + """ + + self._check_interpolators(tel_id) + + mjd = time.tai.mjd + az = u.Quantity(self._interpolators[tel_id]["az"](mjd), u.rad, copy=False) + alt = u.Quantity(self._interpolators[tel_id]["alt"](mjd), u.rad, copy=False) + return alt, az + + def add_table(self, tel_id, input_table): + """ + Add a table to this interpolator + + Parameters + ---------- + tel_id : int + Telescope id + input_table : astropy.table.Table + Table of pointing values, expected columns + are ``time`` as ``Time`` column, ``azimuth`` and ``altitude`` + as quantity columns for pointing and pointing correction data. + """ + + self._check_tables(input_table) + + if not isinstance(input_table["time"], Time): + raise TypeError("'time' column of pointing table must be astropy.time.Time") + + input_table = input_table.copy() + input_table.sort("time") + + az = input_table["azimuth"].quantity.to_value(u.rad) + # prepare azimuth for interpolation by "unwrapping": i.e. turning + # [359, 1] into [359, 361]. This assumes that if we get values like + # [359, 1] the telescope moved 2 degrees through 0, not 358 degrees + # the other way around. This should be true for all telescopes given + # the sampling speed of pointing values and their maximum movement speed. + # No telescope can turn more than 180° in 2 seconds. + az = np.unwrap(az) + alt = input_table["altitude"].quantity.to_value(u.rad) + mjd = input_table["time"].tai.mjd + self._interpolators[tel_id] = {} + self._interpolators[tel_id]["az"] = interp1d(mjd, az, **self.interp_options) + self._interpolators[tel_id]["alt"] = interp1d(mjd, alt, **self.interp_options) diff --git a/src/ctapipe/io/pointing.py b/src/ctapipe/io/pointing.py deleted file mode 100644 index 412c0510a98..00000000000 --- a/src/ctapipe/io/pointing.py +++ /dev/null @@ -1,137 +0,0 @@ -from typing import Any - -import astropy.units as u -import numpy as np -import tables -from astropy.time import Time -from scipy.interpolate import interp1d - -from ctapipe.core import Component, traits - -from .astropy_helpers import read_table - - -class PointingInterpolator(Component): - """ - Interpolate pointing from a monitoring table to a given timestamp. - - Parameters - ---------- - h5file : None | tables.File - A open hdf5 file with read access. - The monitoring table is expected to be stored in that file at - ``/dl0/monitoring/telescope/pointing/tel_{tel_id:03d}`` - - If not given, monitoring tables can be added via `PointingInterpolator.add_table`. - """ - - bounds_error = traits.Bool( - default_value=True, - help="If true, raises an exception when trying to extrapolate out of the given table", - ).tag(config=True) - - extrapolate = traits.Bool( - help="If bounds_error is False, this flag will specify whether values outside" - "the available values are filled with nan (False) or extrapolated (True).", - default_value=False, - ).tag(config=True) - - def __init__(self, h5file=None, **kwargs): - super().__init__(**kwargs) - - if h5file is not None and not isinstance(h5file, tables.File): - raise TypeError("h5file must be a tables.File") - self.h5file = h5file - - self.interp_options: dict[str, Any] = dict(assume_sorted=True, copy=False) - if self.bounds_error: - self.interp_options["bounds_error"] = True - elif self.extrapolate: - self.interp_options["bounds_error"] = False - self.interp_options["fill_value"] = "extrapolate" - else: - self.interp_options["bounds_error"] = False - self.interp_options["fill_value"] = np.nan - - self._alt_interpolators = {} - self._az_interpolators = {} - - def add_table(self, tel_id, pointing_table): - """ - Add a table to this interpolator - - Parameters - ---------- - tel_id : int - Telescope id - pointing_table : astropy.table.Table - Table of pointing values, expected columns - are ``time`` as ``Time`` column, ``azimuth`` and ``altitude`` - as quantity columns. - """ - missing = {"time", "azimuth", "altitude"} - set(pointing_table.colnames) - if len(missing) > 0: - raise ValueError(f"Table is missing required column(s): {missing}") - - if not isinstance(pointing_table["time"], Time): - raise TypeError("'time' column of pointing table must be astropy.time.Time") - - for col in ("azimuth", "altitude"): - unit = pointing_table[col].unit - if unit is None or not u.rad.is_equivalent(unit): - raise ValueError(f"{col} must have units compatible with 'rad'") - - # sort first, so it's not done twice for each interpolator - pointing_table.sort("time") - # interpolate in mjd TAI. Float64 mjd is precise enough for pointing - # and TAI is contiguous, so no issues with leap seconds. - mjd = pointing_table["time"].tai.mjd - - az = pointing_table["azimuth"].quantity.to_value(u.rad) - # prepare azimuth for interpolation by "unwrapping": i.e. turning - # [359, 1] into [359, 361]. This assumes that if we get values like - # [359, 1] the telescope moved 2 degrees through 0, not 358 degrees - # the other way around. This should be true for all telescopes given - # the sampling speed of pointing values and their maximum movement speed. - # No telescope can turn more than 180° in 2 seconds. - az = np.unwrap(az) - alt = pointing_table["altitude"].quantity.to_value(u.rad) - - self._az_interpolators[tel_id] = interp1d(mjd, az, **self.interp_options) - self._alt_interpolators[tel_id] = interp1d(mjd, alt, **self.interp_options) - - def _read_pointing_table(self, tel_id): - pointing_table = read_table( - self.h5file, - f"/dl0/monitoring/telescope/pointing/tel_{tel_id:03d}", - ) - self.add_table(tel_id, pointing_table) - - def __call__(self, tel_id, time): - """ - Interpolate alt/az for given time and tel_id. - - Parameters - ---------- - tel_id : int - telescope id - time : astropy.time.Time - time for which to interpolate the pointing - - Returns - ------- - altitude : astropy.units.Quantity[deg] - interpolated altitude angle - azimuth : astropy.units.Quantity[deg] - interpolated azimuth angle - """ - if tel_id not in self._az_interpolators: - if self.h5file is not None: - self._read_pointing_table(tel_id) - else: - raise KeyError(f"No pointing table available for tel_id {tel_id}") - - mjd = time.tai.mjd - az = u.Quantity(self._az_interpolators[tel_id](mjd), u.rad, copy=False) - alt = u.Quantity(self._alt_interpolators[tel_id](mjd), u.rad, copy=False) - return alt, az diff --git a/src/ctapipe/io/tableloader.py b/src/ctapipe/io/tableloader.py index 36a954fb7de..6da74b6f081 100644 --- a/src/ctapipe/io/tableloader.py +++ b/src/ctapipe/io/tableloader.py @@ -14,7 +14,7 @@ from ..core import Component, Provenance, traits from ..instrument import FocalLengthKind, SubarrayDescription from .astropy_helpers import join_allow_empty, read_table -from .pointing import PointingInterpolator +from .interpolation import PointingInterpolator __all__ = ["TableLoader"] diff --git a/src/ctapipe/io/tests/test_pointing.py b/src/ctapipe/io/tests/test_interpolator.py similarity index 65% rename from src/ctapipe/io/tests/test_pointing.py rename to src/ctapipe/io/tests/test_interpolator.py index 10913053c18..02f4c4ce306 100644 --- a/src/ctapipe/io/tests/test_pointing.py +++ b/src/ctapipe/io/tests/test_interpolator.py @@ -5,37 +5,15 @@ from astropy.table import Table from astropy.time import Time +from ctapipe.io.interpolation import ( + PointingInterpolator, +) -def test_simple(): - """Test pointing interpolation""" - from ctapipe.io.pointing import PointingInterpolator - - t0 = Time("2022-01-01T00:00:00") - - table = Table( - { - "time": t0 + np.arange(0.0, 10.1, 2.0) * u.s, - "azimuth": np.linspace(0.0, 10.0, 6) * u.deg, - "altitude": np.linspace(70.0, 60.0, 6) * u.deg, - }, - ) - - interpolator = PointingInterpolator() - interpolator.add_table(1, table) - - alt, az = interpolator(tel_id=1, time=t0 + 1 * u.s) - assert u.isclose(alt, 69 * u.deg) - assert u.isclose(az, 1 * u.deg) - - with pytest.raises(KeyError): - interpolator(tel_id=2, time=t0 + 1 * u.s) +t0 = Time("2022-01-01T00:00:00") def test_azimuth_switchover(): """Test pointing interpolation""" - from ctapipe.io.pointing import PointingInterpolator - - t0 = Time("2022-01-01T00:00:00") table = Table( { @@ -55,7 +33,6 @@ def test_azimuth_switchover(): def test_invalid_input(): """Test invalid pointing tables raise nice errors""" - from ctapipe.io.pointing import PointingInterpolator wrong_time = Table( { @@ -91,10 +68,8 @@ def test_invalid_input(): def test_hdf5(tmp_path): + """Test writing interpolated data to file""" from ctapipe.io import write_table - from ctapipe.io.pointing import PointingInterpolator - - t0 = Time("2022-01-01T00:00:00") table = Table( { @@ -115,11 +90,8 @@ def test_hdf5(tmp_path): def test_bounds(): """Test invalid pointing tables raise nice errors""" - from ctapipe.io.pointing import PointingInterpolator - t0 = Time("2022-01-01T00:00:00") - - table = Table( + table_pointing = Table( { "time": t0 + np.arange(0.0, 10.1, 2.0) * u.s, "azimuth": np.linspace(0.0, 10.0, 6) * u.deg, @@ -127,28 +99,36 @@ def test_bounds(): }, ) - interpolator = PointingInterpolator() - interpolator.add_table(1, table) + interpolator_pointing = PointingInterpolator() + interpolator_pointing.add_table(1, table_pointing) + error_message = "below the interpolation range" - with pytest.raises(ValueError, match="below the interpolation range"): - interpolator(tel_id=1, time=t0 - 0.1 * u.s) + with pytest.raises(ValueError, match=error_message): + interpolator_pointing(tel_id=1, time=t0 - 0.1 * u.s) with pytest.raises(ValueError, match="above the interpolation range"): - interpolator(tel_id=1, time=t0 + 10.2 * u.s) + interpolator_pointing(tel_id=1, time=t0 + 10.2 * u.s) - interpolator = PointingInterpolator(bounds_error=False) - interpolator.add_table(1, table) + alt, az = interpolator_pointing(tel_id=1, time=t0 + 1 * u.s) + assert u.isclose(alt, 69 * u.deg) + assert u.isclose(az, 1 * u.deg) + + with pytest.raises(KeyError): + interpolator_pointing(tel_id=2, time=t0 + 1 * u.s) + + interpolator_pointing = PointingInterpolator(bounds_error=False) + interpolator_pointing.add_table(1, table_pointing) for dt in (-0.1, 10.1) * u.s: - alt, az = interpolator(tel_id=1, time=t0 + dt) + alt, az = interpolator_pointing(tel_id=1, time=t0 + dt) assert np.isnan(alt.value) assert np.isnan(az.value) - interpolator = PointingInterpolator(bounds_error=False, extrapolate=True) - interpolator.add_table(1, table) - alt, az = interpolator(tel_id=1, time=t0 - 1 * u.s) + interpolator_pointing = PointingInterpolator(bounds_error=False, extrapolate=True) + interpolator_pointing.add_table(1, table_pointing) + alt, az = interpolator_pointing(tel_id=1, time=t0 - 1 * u.s) assert u.isclose(alt, 71 * u.deg) assert u.isclose(az, -1 * u.deg) - alt, az = interpolator(tel_id=1, time=t0 + 11 * u.s) + alt, az = interpolator_pointing(tel_id=1, time=t0 + 11 * u.s) assert u.isclose(alt, 59 * u.deg) assert u.isclose(az, 11 * u.deg)