From 880acb3c8f751ad40214287c7dcc82dc573730e4 Mon Sep 17 00:00:00 2001 From: Nicolas Tessore <n.tessore@ucl.ac.uk> Date: Sat, 11 Nov 2023 13:35:40 +0000 Subject: [PATCH] ENH(io): TocFits class for FITS-backed tocdicts (#53) Add a `TocFits` class that implements a FITS-backed mapping compatible with `TocDict`. The base class is generic, concrete implementations are `AlmFits`, `ClsFits`, and `MmsFits`. Closes: #47 --- heracles/io.py | 264 +++++++++++++++++++++++++++++++++++-------- heracles/twopoint.py | 41 +++++-- tests/test_io.py | 112 ++++++++++++++++++ 3 files changed, 358 insertions(+), 59 deletions(-) diff --git a/heracles/io.py b/heracles/io.py index d5a6e3d..a40c0f5 100644 --- a/heracles/io.py +++ b/heracles/io.py @@ -20,6 +20,11 @@ import logging import os +from collections.abc import MutableMapping +from functools import partial +from pathlib import Path +from types import MappingProxyType +from weakref import WeakValueDictionary import fitsio import healpy as hp @@ -61,7 +66,29 @@ def _read_metadata(hdu): return md -def _as_twopoint(arr, name): +def _write_complex(fits, ext, arr): + """write complex-valued data to FITS table""" + # write the data + fits.write_table([arr.real, arr.imag], names=["real", "imag"], extname=ext) + + # write the metadata + _write_metadata(fits[ext], arr.dtype.metadata) + + +def _read_complex(fits, ext): + """read complex-valued data from FITS table""" + # read structured data as complex array + raw = fits[ext].read() + arr = np.empty(len(raw), dtype=complex) + arr.real = raw["real"] + arr.imag = raw["imag"] + del raw + # read and attach metadata + arr.dtype = np.dtype(arr.dtype, metadata=_read_metadata(fits[ext])) + return arr + + +def _write_twopoint(fits, ext, arr, name): """convert two-point data (i.e. 
one L column) to structured array""" arr = np.asanyarray(arr) @@ -89,6 +116,19 @@ def _as_twopoint(arr, name): arr["LMAX"] = arr["L"] + 1 arr["W"] = 1 + # write the twopoint data + fits.write_table(arr, extname=ext) + + # write the metadata + _write_metadata(fits[ext], arr.dtype.metadata) + + +def _read_twopoint(fits, ext): + """read two-point data from FITS""" + # read data from extension + arr = fits[ext].read() + # read and attach metadata + arr.dtype = np.dtype(arr.dtype, metadata=_read_metadata(fits[ext])) return arr @@ -322,11 +362,8 @@ def write_alms( ext = f"ALM{almn}" almn += 1 - # write the data - fits.write_table([alm.real, alm.imag], names=["real", "imag"], extname=ext) - - # write the metadata - _write_metadata(fits[ext], alm.dtype.metadata) + # write the alm as structured data with metadata + _write_complex(fits, ext, alm) # write the TOC entry tocentry[0] = (ext, n, i) @@ -361,18 +398,8 @@ def read_alms(filename, workdir=".", *, include=None, exclude=None): logger.info("reading %s alm for bin %s", n, i) - # read the alm from the extension - raw = fits[ext].read() - alm = np.empty(len(raw), dtype=complex) - alm.real = raw["real"] - alm.imag = raw["imag"] - del raw - - # read and attach metadata - alm.dtype = np.dtype(alm.dtype, metadata=_read_metadata(fits[ext])) - - # store in set of alms - alms[n, i] = alm + # read the alm from the extension and store in set of alms + alms[n, i] = _read_complex(fits, ext) logger.info("done with %d alms", len(alms)) @@ -428,14 +455,8 @@ def write_cls(filename, cls, *, clobber=False, workdir=".", include=None, exclud ext = f"CL{cln}" cln += 1 - # get the data into structured format if not already - cl = _as_twopoint(cl, "CL") - - # write the data columns - fits.write_table(cl, extname=ext) - - # write the metadata - _write_metadata(fits[ext], cl.dtype.metadata) + # write the data in structured format + _write_twopoint(fits, ext, cl, "CL") # write the TOC entry tocentry[0] = (ext, k1, k2, i1, i2) @@ -470,14 +491,8 @@ 
def read_cls(filename, workdir=".", *, include=None, exclude=None): logger.info("reading %s x %s cl for bins %s, %s", k1, k2, i1, i2) - # read the cl from the extension - cl = fits[ext].read() - - # read and attach metadata - cl.dtype = np.dtype(cl.dtype, metadata=_read_metadata(fits[ext])) - - # store in set of cls - cls[k1, k2, i1, i2] = cl + # read the cl from the extension and store in set of cls + cls[k1, k2, i1, i2] = _read_twopoint(fits, ext) logger.info("done with %d cls", len(cls)) @@ -533,14 +548,8 @@ def write_mms(filename, mms, *, clobber=False, workdir=".", include=None, exclud ext = f"MM{mmn}" mmn += 1 - # get the data into structured format if not already - mm = _as_twopoint(mm, "MM") - - # write the data columns - fits.write_table(mm, extname=ext) - - # write the metadata - _write_metadata(fits[ext], mm.dtype.metadata) + # write the data in structured format + _write_twopoint(fits, ext, mm, "MM") # write the TOC entry tocentry[0] = (ext, n, i1, i2) @@ -575,14 +584,8 @@ def read_mms(filename, workdir=".", *, include=None, exclude=None): logger.info("reading mixing matrix %s for bins %s, %s", n, i1, i2) - # read the mixing matrix from the extension - mm = fits[ext].read() - - # read and attach metadata - mm.dtype = np.dtype(mm.dtype, metadata=_read_metadata(fits[ext])) - - # store in set of mms - mms[n, i1, i2] = mm + # read the mixing matrix from the extension and store in set of mms + mms[n, i1, i2] = _read_twopoint(fits, ext) logger.info("done with %d mm(s)", len(mms)) @@ -711,3 +714,164 @@ def read_cov(filename, workdir=".", *, include=None, exclude=None): # return the toc dict of covariances return cov + + +class TocFits(MutableMapping): + """A FITS-backed TocDict.""" + + tag = "EXT" + """Tag for FITS extensions.""" + + columns = {} + """Columns and their formats in the FITS table of contents.""" + + @staticmethod + def reader(fits, ext): + """Read data from FITS extension.""" + return fits[ext].read() + + @staticmethod + def writer(fits, ext, 
data): + """Write data to FITS extension.""" + if data.dtype.names is None: + msg = "data must be structured array" + raise TypeError(msg) + fits.write_table(data, extname=ext) + + @property + def fits(self): + """Return an opened FITS context manager.""" + return fitsio.FITS(self.path, mode="rw", clobber=False) + + @property + def toc(self): + """Return a view of the FITS table of contents.""" + return MappingProxyType(self._toc) + + def __init__(self, path, *, clobber=False): + self.path = Path(path) + + # FITS extension for table of contents + self.ext = f"{self.tag.upper()}TOC" + + # if new or overwriting, create an empty FITS with primary HDU + if not self.path.exists() or clobber: + with fitsio.FITS(self.path, mode="rw", clobber=True) as fits: + fits.write(None) + + # reopen FITS for writing data + with self.fits as fits: + # write a new ToC extension if FITS doesn't already contain one + if self.ext not in fits: + fits.create_table_hdu( + names=["EXT", *self.columns.keys()], + formats=["10A", *self.columns.values()], + extname=self.ext, + ) + + # get the dtype for ToC entries + self.dtype = fits[self.ext].get_rec_dtype()[0] + + # empty ToC + self._toc = TocDict() + else: + # read existing ToC from FITS + toc = fits[self.ext].read() + + # store the dtype for ToC entries + self.dtype = toc.dtype + + # store the ToC as a mapping + self._toc = TocDict({tuple(key): str(ext) for ext, *key in toc}) + + # set up a weakly-referenced cache for extension data + self._cache = WeakValueDictionary() + + def __len__(self): + return len(self._toc) + + def __iter__(self): + return iter(self._toc) + + def __contains__(self, key): + if not isinstance(key, tuple): + key = (key,) + return key in self._toc + + def __getitem__(self, key): + ext = self._toc[key] + + # if a TocDict is returned, we have the result of a selection + if isinstance(ext, TocDict): + # make a new instance and copy attributes + selected = object.__new__(self.__class__) + selected.path = self.path + # shared
cache since both instances read the same file + selected._cache = self._cache + # the new toc contains the result of the selection + selected._toc = ext + return selected + + # a specific extension was requested, fetch data + data = self._cache.get(ext) + if data is None: + with self.fits as fits: + data = self.reader(fits, ext) + self._cache[ext] = data + return data + + def __setitem__(self, key, value): + # keys are always tuples + if not isinstance(key, tuple): + key = (key,) + + # check if an extension with the given key already exists + # otherwise, get the first free extension with the given tag + if key in self._toc: + ext = self._toc[key] + else: + extn = len(self._toc) + exts = set(self._toc.values()) + while (ext := f"{self.tag.upper()}{extn}") in exts: + extn += 1 + + # write data using the class writer, and update ToC as necessary + with self.fits as fits: + self.writer(fits, ext, value) + if key not in self._toc: + tocentry = np.empty(1, dtype=self.dtype) + tocentry[0] = (ext, *key) + fits[self.ext].append(tocentry) + self._toc[key] = ext + + def __delitem__(self, key): + # fitsio does not support deletion of extensions + msg = "deleting FITS extensions is not supported" + raise NotImplementedError(msg) + + +class AlmFits(TocFits): + """FITS-backed mapping for alms.""" + + tag = "ALM" + columns = {"NAME": "10A", "BIN": "I"} + reader = staticmethod(_read_complex) + writer = staticmethod(_write_complex) + + +class ClsFits(TocFits): + """FITS-backed mapping for cls.""" + + tag = "CL" + columns = {"NAME1": "10A", "NAME2": "10A", "BIN1": "I", "BIN2": "I"} + reader = staticmethod(_read_twopoint) + writer = partial(_write_twopoint, name=tag) + + +class MmsFits(TocFits): + """FITS-backed mapping for mixing matrices.""" + + tag = "MM" + columns = {"NAME": "10A", "BIN1": "I", "BIN2": "I"} + reader = staticmethod(_read_twopoint) + writer = partial(_write_twopoint, name=tag) diff --git a/heracles/twopoint.py b/heracles/twopoint.py index 3235079..be44f09 100644 
--- a/heracles/twopoint.py +++ b/heracles/twopoint.py @@ -41,7 +41,15 @@ logger = logging.getLogger(__name__) -def angular_power_spectra(alms, alms2=None, *, lmax=None, include=None, exclude=None): +def angular_power_spectra( + alms, + alms2=None, + *, + lmax=None, + include=None, + exclude=None, + out=None, +): """compute angular power spectra from a set of alms""" logger.info( @@ -55,25 +63,34 @@ def angular_power_spectra(alms, alms2=None, *, lmax=None, include=None, exclude= # collect all alm combinations for computing cls if alms2 is None: - alm_pairs = combinations_with_replacement(alms.items(), 2) + pairs = combinations_with_replacement(alms, 2) + alms2 = alms else: - alm_pairs = product(alms.items(), alms2.items()) + pairs = product(alms, alms2) # keep track of the twopoint combinations we have seen here twopoint_names = set() + # output tocdict, use given or empty + if out is None: + cls = TocDict() + else: + cls = out + # compute cls for all alm pairs # do not compute duplicates - cls = TocDict() - for ((k1, i1), alm1), ((k2, i2), alm2) in alm_pairs: + for (k1, i1), (k2, i2) in pairs: + # skip duplicate cls in any order + if (k1, k2, i1, i2) in cls or (k2, k1, i2, i1) in cls: + continue + # get the two-point code in standard order if (k1, k2) not in twopoint_names and (k2, k1) in twopoint_names: i1, i2 = i2, i1 k1, k2 = k2, k1 - - # skip duplicate cls in any order - if (k1, k2, i1, i2) in cls or (k2, k1, i2, i1) in cls: - continue + swapped = True + else: + swapped = False # check if cl is skipped by explicit include or exclude list if not toc_match((k1, k2, i1, i2), include, exclude): @@ -81,6 +98,12 @@ def angular_power_spectra(alms, alms2=None, *, lmax=None, include=None, exclude= logger.info("computing %s x %s cl for bins %s, %s", k1, k2, i1, i2) + # retrieve alms from keys; make sure swap is respected + # this is done only now because alms might lazy-load from file + alm1, alm2 = alms[k1, i1], alms2[k2, i2] + if swapped: + alm1, alm2 = alm2, alm1 + # 
compute the raw cl from the alms cl = hp.alm2cl(alm1, alm2, lmax_out=lmax) diff --git a/tests/test_io.py b/tests/test_io.py index 4c64702..46f4611 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -375,3 +375,115 @@ def test_read_mask_extra( extra_mask_name=mock_writemask_extra, ) assert (mask == maps[:, ibin] * maps_extra[:]).all() + + +def test_tocfits(tmp_path): + import fitsio + import numpy as np + + from heracles.io import TocFits + + class TestFits(TocFits): + tag = "test" + columns = {"col1": "I", "col2": "J"} + + path = tmp_path / "test.fits" + + assert not path.exists() + + tocfits = TestFits(path, clobber=True) + + assert path.exists() + + with fitsio.FITS(path) as fits: + assert len(fits) == 2 + toc = fits["TESTTOC"].read() + assert toc.dtype.names == ("EXT", "col1", "col2") + assert len(toc) == 0 + + assert len(tocfits) == 0 + assert list(tocfits) == [] + assert tocfits.toc == {} + + data12 = np.zeros(5, dtype=[("X", float), ("Y", int)]) + data22 = np.ones(5, dtype=[("X", float), ("Y", int)]) + + tocfits[1, 2] = data12 + + with fitsio.FITS(path) as fits: + assert len(fits) == 3 + toc = fits["TESTTOC"].read() + assert len(toc) == 1 + np.testing.assert_array_equal(fits["TEST0"].read(), data12) + + assert len(tocfits) == 1 + assert list(tocfits) == [(1, 2)] + assert tocfits.toc == {(1, 2): "TEST0"} + np.testing.assert_array_equal(tocfits[1, 2], data12) + + tocfits[2, 2] = data22 + + with fitsio.FITS(path) as fits: + assert len(fits) == 4 + toc = fits["TESTTOC"].read() + assert len(toc) == 2 + np.testing.assert_array_equal(fits["TEST0"].read(), data12) + np.testing.assert_array_equal(fits["TEST1"].read(), data22) + + assert len(tocfits) == 2 + assert list(tocfits) == [(1, 2), (2, 2)] + assert tocfits.toc == {(1, 2): "TEST0", (2, 2): "TEST1"} + np.testing.assert_array_equal(tocfits[1, 2], data12) + np.testing.assert_array_equal(tocfits[2, 2], data22) + + with pytest.raises(NotImplementedError): + del tocfits[1, 2] + + del tocfits + + tocfits2 = 
TestFits(path, clobber=False) + + assert len(tocfits2) == 2 + assert list(tocfits2) == [(1, 2), (2, 2)] + assert tocfits2.toc == {(1, 2): "TEST0", (2, 2): "TEST1"} + np.testing.assert_array_equal(tocfits2[1, 2], data12) + np.testing.assert_array_equal(tocfits2[2, 2], data22) + + +def test_tocfits_is_lazy(tmp_path): + import fitsio + + from heracles.io import TocFits + + path = tmp_path / "bad.fits" + + # test keys(), values(), and items() are not eagerly reading data + tocfits = TocFits(path, clobber=True) + + # manually enter some non-existent rows into the ToC + assert tocfits._toc == {} + tocfits._toc[0,] = "BAD0" + tocfits._toc[1,] = "BAD1" + tocfits._toc[2,] = "BAD2" + + # these should not error + tocfits.keys() + tocfits.values() + tocfits.items() + + # contains and iteration are lazy + assert 0 in tocfits + assert list(tocfits) == [(0,), (1,), (2,)] + + # subselection should work fine + selected = tocfits[...] + assert isinstance(selected, TocFits) + assert len(selected) == 3 + + # make sure nothing is in the FITS + with fitsio.FITS(path, "r") as fits: + assert len(fits) == 2 + + # make sure there are errors when actualising the generators + with pytest.raises(OSError): + list(tocfits.values())