From 880acb3c8f751ad40214287c7dcc82dc573730e4 Mon Sep 17 00:00:00 2001
From: Nicolas Tessore <n.tessore@ucl.ac.uk>
Date: Sat, 11 Nov 2023 13:35:40 +0000
Subject: [PATCH] ENH(io): TocFits class for FITS-backed tocdicts (#53)

Add a `TocFits` class that implements a FITS-backed mapping compatible
with `TocDict`. The base class is generic, concrete implementations are
`AlmFits`, `ClsFits`, and `MmsFits`.

Closes: #47
---
 heracles/io.py       | 264 +++++++++++++++++++++++++++++++++++--------
 heracles/twopoint.py |  41 +++++--
 tests/test_io.py     | 112 ++++++++++++++++++
 3 files changed, 358 insertions(+), 59 deletions(-)

diff --git a/heracles/io.py b/heracles/io.py
index d5a6e3d..a40c0f5 100644
--- a/heracles/io.py
+++ b/heracles/io.py
@@ -20,6 +20,11 @@
 
 import logging
 import os
+from collections.abc import MutableMapping
+from functools import partial
+from pathlib import Path
+from types import MappingProxyType
+from weakref import WeakValueDictionary
 
 import fitsio
 import healpy as hp
@@ -61,7 +66,29 @@ def _read_metadata(hdu):
     return md
 
 
-def _as_twopoint(arr, name):
+def _write_complex(fits, ext, arr):
+    """write complex-valued data to FITS table"""
+    # write the data
+    fits.write_table([arr.real, arr.imag], names=["real", "imag"], extname=ext)
+
+    # write the metadata
+    _write_metadata(fits[ext], arr.dtype.metadata)
+
+
+def _read_complex(fits, ext):
+    """read complex-valued data from FITS table"""
+    # read structured data as complex array
+    raw = fits[ext].read()
+    arr = np.empty(len(raw), dtype=complex)
+    arr.real = raw["real"]
+    arr.imag = raw["imag"]
+    del raw
+    # read and attach metadata
+    arr.dtype = np.dtype(arr.dtype, metadata=_read_metadata(fits[ext]))
+    return arr
+
+
+def _write_twopoint(fits, ext, arr, name):
     """convert two-point data (i.e. one L column) to structured array"""
 
     arr = np.asanyarray(arr)
@@ -89,6 +116,19 @@ def _as_twopoint(arr, name):
         arr["LMAX"] = arr["L"] + 1
         arr["W"] = 1
 
+    # write the twopoint data
+    fits.write_table(arr, extname=ext)
+
+    # write the metadata
+    _write_metadata(fits[ext], arr.dtype.metadata)
+
+
+def _read_twopoint(fits, ext):
+    """read two-point data from FITS"""
+    # read data from extension
+    arr = fits[ext].read()
+    # read and attach metadata
+    arr.dtype = np.dtype(arr.dtype, metadata=_read_metadata(fits[ext]))
     return arr
 
 
@@ -322,11 +362,8 @@ def write_alms(
             ext = f"ALM{almn}"
             almn += 1
 
-            # write the data
-            fits.write_table([alm.real, alm.imag], names=["real", "imag"], extname=ext)
-
-            # write the metadata
-            _write_metadata(fits[ext], alm.dtype.metadata)
+            # write the alm as structured data with metadata
+            _write_complex(fits, ext, alm)
 
             # write the TOC entry
             tocentry[0] = (ext, n, i)
@@ -361,18 +398,8 @@ def read_alms(filename, workdir=".", *, include=None, exclude=None):
 
             logger.info("reading %s alm for bin %s", n, i)
 
-            # read the alm from the extension
-            raw = fits[ext].read()
-            alm = np.empty(len(raw), dtype=complex)
-            alm.real = raw["real"]
-            alm.imag = raw["imag"]
-            del raw
-
-            # read and attach metadata
-            alm.dtype = np.dtype(alm.dtype, metadata=_read_metadata(fits[ext]))
-
-            # store in set of alms
-            alms[n, i] = alm
+            # read the alm from the extension and store in set of alms
+            alms[n, i] = _read_complex(fits, ext)
 
     logger.info("done with %d alms", len(alms))
 
@@ -428,14 +455,8 @@ def write_cls(filename, cls, *, clobber=False, workdir=".", include=None, exclud
             ext = f"CL{cln}"
             cln += 1
 
-            # get the data into structured format if not already
-            cl = _as_twopoint(cl, "CL")
-
-            # write the data columns
-            fits.write_table(cl, extname=ext)
-
-            # write the metadata
-            _write_metadata(fits[ext], cl.dtype.metadata)
+            # write the data in structured format
+            _write_twopoint(fits, ext, cl, "CL")
 
             # write the TOC entry
             tocentry[0] = (ext, k1, k2, i1, i2)
@@ -470,14 +491,8 @@ def read_cls(filename, workdir=".", *, include=None, exclude=None):
 
             logger.info("reading %s x %s cl for bins %s, %s", k1, k2, i1, i2)
 
-            # read the cl from the extension
-            cl = fits[ext].read()
-
-            # read and attach metadata
-            cl.dtype = np.dtype(cl.dtype, metadata=_read_metadata(fits[ext]))
-
-            # store in set of cls
-            cls[k1, k2, i1, i2] = cl
+            # read the cl from the extension and store in set of cls
+            cls[k1, k2, i1, i2] = _read_twopoint(fits, ext)
 
     logger.info("done with %d cls", len(cls))
 
@@ -533,14 +548,8 @@ def write_mms(filename, mms, *, clobber=False, workdir=".", include=None, exclud
             ext = f"MM{mmn}"
             mmn += 1
 
-            # get the data into structured format if not already
-            mm = _as_twopoint(mm, "MM")
-
-            # write the data columns
-            fits.write_table(mm, extname=ext)
-
-            # write the metadata
-            _write_metadata(fits[ext], mm.dtype.metadata)
+            # write the data in structured format
+            _write_twopoint(fits, ext, mm, "MM")
 
             # write the TOC entry
             tocentry[0] = (ext, n, i1, i2)
@@ -575,14 +584,8 @@ def read_mms(filename, workdir=".", *, include=None, exclude=None):
 
             logger.info("reading mixing matrix %s for bins %s, %s", n, i1, i2)
 
-            # read the mixing matrix from the extension
-            mm = fits[ext].read()
-
-            # read and attach metadata
-            mm.dtype = np.dtype(mm.dtype, metadata=_read_metadata(fits[ext]))
-
-            # store in set of mms
-            mms[n, i1, i2] = mm
+            # read the mixing matrix from the extension and store in set of mms
+            mms[n, i1, i2] = _read_twopoint(fits, ext)
 
     logger.info("done with %d mm(s)", len(mms))
 
@@ -711,3 +714,164 @@ def read_cov(filename, workdir=".", *, include=None, exclude=None):
 
     # return the toc dict of covariances
     return cov
+
+
+class TocFits(MutableMapping):
+    """A FITS-backed TocDict."""
+
+    tag = "EXT"
+    """Tag for FITS extensions."""
+
+    columns = {}
+    """Columns and their formats in the FITS table of contents."""
+
+    @staticmethod
+    def reader(fits, ext):
+        """Read data from FITS extension."""
+        return fits[ext].read()
+
+    @staticmethod
+    def writer(fits, ext, data):
+        """Write data to FITS extension."""
+        if data.dtype.names is None:
+            msg = "data must be structured array"
+            raise TypeError(msg)
+        fits.write_table(data, extname=ext)
+
+    @property
+    def fits(self):
+        """Return an opened FITS context manager."""
+        return fitsio.FITS(self.path, mode="rw", clobber=False)
+
+    @property
+    def toc(self):
+        """Return a view of the FITS table of contents."""
+        return MappingProxyType(self._toc)
+
+    def __init__(self, path, *, clobber=False):
+        self.path = Path(path)
+
+        # FITS extension for table of contents
+        self.ext = f"{self.tag.upper()}TOC"
+
+        # if new or overwriting, create an empty FITS with primary HDU
+        if not self.path.exists() or clobber:
+            with fitsio.FITS(self.path, mode="rw", clobber=True) as fits:
+                fits.write(None)
+
+        # reopen FITS for writing data
+        with self.fits as fits:
+            # write a new ToC extension if FITS doesn't already contain one
+            if self.ext not in fits:
+                fits.create_table_hdu(
+                    names=["EXT", *self.columns.keys()],
+                    formats=["10A", *self.columns.values()],
+                    extname=self.ext,
+                )
+
+                # get the dtype for ToC entries
+                self.dtype = fits[self.ext].get_rec_dtype()[0]
+
+                # empty ToC
+                self._toc = TocDict()
+            else:
+                # read existing ToC from FITS
+                toc = fits[self.ext].read()
+
+                # store the dtype for ToC entries
+                self.dtype = toc.dtype
+
+                # store the ToC as a mapping
+                self._toc = TocDict({tuple(key): str(ext) for ext, *key in toc})
+
+        # set up a weakly-referenced cache for extension data
+        self._cache = WeakValueDictionary()
+
+    def __len__(self):
+        return len(self._toc)
+
+    def __iter__(self):
+        return iter(self._toc)
+
+    def __contains__(self, key):
+        if not isinstance(key, tuple):
+            key = (key,)
+        return key in self._toc
+
+    def __getitem__(self, key):
+        ext = self._toc[key]
+
+        # if a TocDict is returned, we have the result of a selection
+        if isinstance(ext, TocDict):
+            # new instance; copy attrs (ext, dtype) so it stays writable
+            selected = object.__new__(self.__class__)
+            selected.path, selected.ext = self.path, self.ext
+            selected.dtype = getattr(self, "dtype", None)
+            # shared cache since both instances read the same file
+            selected._cache = self._cache
+            selected._toc = ext
+            return selected
+
+        # a specific extension was requested, fetch data
+        data = self._cache.get(ext)
+        if data is None:
+            with self.fits as fits:
+                data = self.reader(fits, ext)
+            self._cache[ext] = data
+        return data
+
+    def __setitem__(self, key, value):
+        # keys are always tuples
+        if not isinstance(key, tuple):
+            key = (key,)
+
+        # check if an extension with the given key already exists
+        # otherwise, get the first free extension with the given tag
+        if key in self._toc:
+            ext = self._toc[key]
+        else:
+            extn = len(self._toc)
+            exts = set(self._toc.values())
+            while (ext := f"{self.tag.upper()}{extn}") in exts:
+                extn += 1
+
+        # write data using the class writer, and update ToC as necessary
+        with self.fits as fits:
+            self.writer(fits, ext, value)
+            if key not in self._toc:
+                tocentry = np.empty(1, dtype=self.dtype)
+                tocentry[0] = (ext, *key)
+                fits[self.ext].append(tocentry)
+                self._toc[key] = ext
+
+    def __delitem__(self, key):
+        # fitsio does not support deletion of extensions
+        msg = "deleting FITS extensions is not supported"
+        raise NotImplementedError(msg)
+
+
+class AlmFits(TocFits):
+    """FITS-backed mapping for alms."""
+
+    tag = "ALM"
+    columns = {"NAME": "10A", "BIN": "I"}
+    reader = staticmethod(_read_complex)
+    writer = staticmethod(_write_complex)
+
+
+class ClsFits(TocFits):
+    """FITS-backed mapping for cls."""
+
+    tag = "CL"
+    columns = {"NAME1": "10A", "NAME2": "10A", "BIN1": "I", "BIN2": "I"}
+    reader = staticmethod(_read_twopoint)
+    writer = partial(_write_twopoint, name=tag)
+
+
+class MmsFits(TocFits):
+    """FITS-backed mapping for mixing matrices."""
+
+    tag = "MM"
+    columns = {"NAME": "10A", "BIN1": "I", "BIN2": "I"}
+    reader = staticmethod(_read_twopoint)
+    writer = partial(_write_twopoint, name=tag)
diff --git a/heracles/twopoint.py b/heracles/twopoint.py
index 3235079..be44f09 100644
--- a/heracles/twopoint.py
+++ b/heracles/twopoint.py
@@ -41,7 +41,15 @@
 logger = logging.getLogger(__name__)
 
 
-def angular_power_spectra(alms, alms2=None, *, lmax=None, include=None, exclude=None):
+def angular_power_spectra(
+    alms,
+    alms2=None,
+    *,
+    lmax=None,
+    include=None,
+    exclude=None,
+    out=None,
+):
     """compute angular power spectra from a set of alms"""
 
     logger.info(
@@ -55,25 +63,34 @@ def angular_power_spectra(alms, alms2=None, *, lmax=None, include=None, exclude=
 
     # collect all alm combinations for computing cls
     if alms2 is None:
-        alm_pairs = combinations_with_replacement(alms.items(), 2)
+        pairs = combinations_with_replacement(alms, 2)
+        alms2 = alms
     else:
-        alm_pairs = product(alms.items(), alms2.items())
+        pairs = product(alms, alms2)
 
     # keep track of the twopoint combinations we have seen here
     twopoint_names = set()
 
+    # output tocdict, use given or empty
+    if out is None:
+        cls = TocDict()
+    else:
+        cls = out
+
     # compute cls for all alm pairs
     # do not compute duplicates
-    cls = TocDict()
-    for ((k1, i1), alm1), ((k2, i2), alm2) in alm_pairs:
+    for (k1, i1), (k2, i2) in pairs:
+        # skip duplicate cls in any order
+        if (k1, k2, i1, i2) in cls or (k2, k1, i2, i1) in cls:
+            continue
+
         # get the two-point code in standard order
         if (k1, k2) not in twopoint_names and (k2, k1) in twopoint_names:
             i1, i2 = i2, i1
             k1, k2 = k2, k1
-
-        # skip duplicate cls in any order
-        if (k1, k2, i1, i2) in cls or (k2, k1, i2, i1) in cls:
-            continue
+            swapped = True
+        else:
+            swapped = False
 
         # check if cl is skipped by explicit include or exclude list
         if not toc_match((k1, k2, i1, i2), include, exclude):
@@ -81,6 +98,12 @@ def angular_power_spectra(alms, alms2=None, *, lmax=None, include=None, exclude=
 
         logger.info("computing %s x %s cl for bins %s, %s", k1, k2, i1, i2)
 
+        # retrieve alms from keys; make sure swap is respected
+        # this is done only now because alms might lazy-load from file
+        alm1, alm2 = alms[k1, i1], alms2[k2, i2]
+        if swapped:
+            alm1, alm2 = alm2, alm1
+
         # compute the raw cl from the alms
         cl = hp.alm2cl(alm1, alm2, lmax_out=lmax)
 
diff --git a/tests/test_io.py b/tests/test_io.py
index 4c64702..46f4611 100644
--- a/tests/test_io.py
+++ b/tests/test_io.py
@@ -375,3 +375,115 @@ def test_read_mask_extra(
         extra_mask_name=mock_writemask_extra,
     )
     assert (mask == maps[:, ibin] * maps_extra[:]).all()
+
+
+def test_tocfits(tmp_path):
+    import fitsio
+    import numpy as np
+
+    from heracles.io import TocFits
+
+    class TestFits(TocFits):
+        tag = "test"
+        columns = {"col1": "I", "col2": "J"}
+
+    path = tmp_path / "test.fits"
+
+    assert not path.exists()
+
+    tocfits = TestFits(path, clobber=True)
+
+    assert path.exists()
+
+    with fitsio.FITS(path) as fits:
+        assert len(fits) == 2
+        toc = fits["TESTTOC"].read()
+    assert toc.dtype.names == ("EXT", "col1", "col2")
+    assert len(toc) == 0
+
+    assert len(tocfits) == 0
+    assert list(tocfits) == []
+    assert tocfits.toc == {}
+
+    data12 = np.zeros(5, dtype=[("X", float), ("Y", int)])
+    data22 = np.ones(5, dtype=[("X", float), ("Y", int)])
+
+    tocfits[1, 2] = data12
+
+    with fitsio.FITS(path) as fits:
+        assert len(fits) == 3
+        toc = fits["TESTTOC"].read()
+        assert len(toc) == 1
+        np.testing.assert_array_equal(fits["TEST0"].read(), data12)
+
+    assert len(tocfits) == 1
+    assert list(tocfits) == [(1, 2)]
+    assert tocfits.toc == {(1, 2): "TEST0"}
+    np.testing.assert_array_equal(tocfits[1, 2], data12)
+
+    tocfits[2, 2] = data22
+
+    with fitsio.FITS(path) as fits:
+        assert len(fits) == 4
+        toc = fits["TESTTOC"].read()
+        assert len(toc) == 2
+        np.testing.assert_array_equal(fits["TEST0"].read(), data12)
+        np.testing.assert_array_equal(fits["TEST1"].read(), data22)
+
+    assert len(tocfits) == 2
+    assert list(tocfits) == [(1, 2), (2, 2)]
+    assert tocfits.toc == {(1, 2): "TEST0", (2, 2): "TEST1"}
+    np.testing.assert_array_equal(tocfits[1, 2], data12)
+    np.testing.assert_array_equal(tocfits[2, 2], data22)
+
+    with pytest.raises(NotImplementedError):
+        del tocfits[1, 2]
+
+    del tocfits
+
+    tocfits2 = TestFits(path, clobber=False)
+
+    assert len(tocfits2) == 2
+    assert list(tocfits2) == [(1, 2), (2, 2)]
+    assert tocfits2.toc == {(1, 2): "TEST0", (2, 2): "TEST1"}
+    np.testing.assert_array_equal(tocfits2[1, 2], data12)
+    np.testing.assert_array_equal(tocfits2[2, 2], data22)
+
+
+def test_tocfits_is_lazy(tmp_path):
+    import fitsio
+
+    from heracles.io import TocFits
+
+    path = tmp_path / "bad.fits"
+
+    # test keys(), values(), and items() are not eagerly reading data
+    tocfits = TocFits(path, clobber=True)
+
+    # manually enter some non-existent rows into the ToC
+    assert tocfits._toc == {}
+    tocfits._toc[0,] = "BAD0"
+    tocfits._toc[1,] = "BAD1"
+    tocfits._toc[2,] = "BAD2"
+
+    # these should not error
+    tocfits.keys()
+    tocfits.values()
+    tocfits.items()
+
+    # contains and iteration are lazy
+    assert 0 in tocfits
+    assert list(tocfits) == [(0,), (1,), (2,)]
+
+    # subselection should work fine
+    selected = tocfits[...]
+    assert isinstance(selected, TocFits)
+    assert len(selected) == 3
+
+    # make sure nothing is in the FITS
+    with fitsio.FITS(path, "r") as fits:
+        assert len(fits) == 2
+
+    # make sure there are errors when actualising the generators
+    with pytest.raises(OSError):
+        list(tocfits.values())