Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH(io): TocFits class for FITS-backed tocdicts #53

Merged
merged 4 commits into from
Nov 11, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
264 changes: 214 additions & 50 deletions heracles/io.py
Original file line number Diff line number Diff line change
@@ -20,6 +20,11 @@

import logging
import os
from collections.abc import MutableMapping
from functools import partial
from pathlib import Path
from types import MappingProxyType
from weakref import WeakValueDictionary

import fitsio
import healpy as hp
@@ -61,7 +66,29 @@ def _read_metadata(hdu):
return md


def _as_twopoint(arr, name):
def _write_complex(fits, ext, arr):
    """write complex-valued data to FITS table"""
    # a complex array cannot be stored directly: split it into two
    # table columns holding the real and imaginary parts
    fits.write_table(
        [arr.real, arr.imag],
        names=["real", "imag"],
        extname=ext,
    )
    # persist the dtype metadata alongside the data in the HDU header
    _write_metadata(fits[ext], arr.dtype.metadata)


def _read_complex(fits, ext):
    """read complex-valued data from FITS table"""
    # recombine the "real"/"imag" table columns into one complex array
    raw = fits[ext].read()
    out = np.empty(len(raw), dtype=complex)
    out.real, out.imag = raw["real"], raw["imag"]
    # drop the structured copy before returning
    del raw
    # attach the metadata stored in the HDU header to the array's dtype
    out.dtype = np.dtype(out.dtype, metadata=_read_metadata(fits[ext]))
    return out


def _write_twopoint(fits, ext, arr, name):
"""convert two-point data (i.e. one L column) to structured array"""

arr = np.asanyarray(arr)
@@ -89,6 +116,19 @@ def _as_twopoint(arr, name):
arr["LMAX"] = arr["L"] + 1
arr["W"] = 1

# write the twopoint data
fits.write_table(arr, extname=ext)

# write the metadata
_write_metadata(fits[ext], arr.dtype.metadata)


def _read_twopoint(fits, ext):
    """read two-point data from FITS"""
    # load the structured array stored in the extension
    data = fits[ext].read()
    # re-attach the metadata stored in the HDU header to the dtype
    data.dtype = np.dtype(data.dtype, metadata=_read_metadata(fits[ext]))
    return data


@@ -322,11 +362,8 @@ def write_alms(
ext = f"ALM{almn}"
almn += 1

# write the data
fits.write_table([alm.real, alm.imag], names=["real", "imag"], extname=ext)

# write the metadata
_write_metadata(fits[ext], alm.dtype.metadata)
# write the alm as structured data with metadata
_write_complex(fits, ext, alm)

# write the TOC entry
tocentry[0] = (ext, n, i)
@@ -361,18 +398,8 @@ def read_alms(filename, workdir=".", *, include=None, exclude=None):

logger.info("reading %s alm for bin %s", n, i)

# read the alm from the extension
raw = fits[ext].read()
alm = np.empty(len(raw), dtype=complex)
alm.real = raw["real"]
alm.imag = raw["imag"]
del raw

# read and attach metadata
alm.dtype = np.dtype(alm.dtype, metadata=_read_metadata(fits[ext]))

# store in set of alms
alms[n, i] = alm
# read the alm from the extension and store in set of alms
alms[n, i] = _read_complex(fits, ext)

logger.info("done with %d alms", len(alms))

@@ -428,14 +455,8 @@ def write_cls(filename, cls, *, clobber=False, workdir=".", include=None, exclud
ext = f"CL{cln}"
cln += 1

# get the data into structured format if not already
cl = _as_twopoint(cl, "CL")

# write the data columns
fits.write_table(cl, extname=ext)

# write the metadata
_write_metadata(fits[ext], cl.dtype.metadata)
# write the data in structured format
_write_twopoint(fits, ext, cl, "CL")

# write the TOC entry
tocentry[0] = (ext, k1, k2, i1, i2)
@@ -470,14 +491,8 @@ def read_cls(filename, workdir=".", *, include=None, exclude=None):

logger.info("reading %s x %s cl for bins %s, %s", k1, k2, i1, i2)

# read the cl from the extension
cl = fits[ext].read()

# read and attach metadata
cl.dtype = np.dtype(cl.dtype, metadata=_read_metadata(fits[ext]))

# store in set of cls
cls[k1, k2, i1, i2] = cl
# read the cl from the extension and store in set of cls
cls[k1, k2, i1, i2] = _read_twopoint(fits, ext)

logger.info("done with %d cls", len(cls))

@@ -533,14 +548,8 @@ def write_mms(filename, mms, *, clobber=False, workdir=".", include=None, exclud
ext = f"MM{mmn}"
mmn += 1

# get the data into structured format if not already
mm = _as_twopoint(mm, "MM")

# write the data columns
fits.write_table(mm, extname=ext)

# write the metadata
_write_metadata(fits[ext], mm.dtype.metadata)
# write the data in structured format
_write_twopoint(fits, ext, mm, "MM")

# write the TOC entry
tocentry[0] = (ext, n, i1, i2)
@@ -575,14 +584,8 @@ def read_mms(filename, workdir=".", *, include=None, exclude=None):

logger.info("reading mixing matrix %s for bins %s, %s", n, i1, i2)

# read the mixing matrix from the extension
mm = fits[ext].read()

# read and attach metadata
mm.dtype = np.dtype(mm.dtype, metadata=_read_metadata(fits[ext]))

# store in set of mms
mms[n, i1, i2] = mm
# read the mixing matrix from the extension and store in set of mms
mms[n, i1, i2] = _read_twopoint(fits, ext)

logger.info("done with %d mm(s)", len(mms))

@@ -711,3 +714,164 @@ def read_cov(filename, workdir=".", *, include=None, exclude=None):

# return the toc dict of covariances
return cov


class TocFits(MutableMapping):
    """A FITS-backed TocDict.

    Maps tuple keys to data stored in individual FITS extensions.  A
    table-of-contents (ToC) extension named ``{tag}TOC`` records which
    key belongs to which extension, so data is read from file lazily,
    only when a key is accessed.
    """

    tag = "EXT"
    """Tag for FITS extensions."""

    columns = {}
    """Columns and their formats in the FITS table of contents."""

    @staticmethod
    def reader(fits, ext):
        """Read data from FITS extension."""
        return fits[ext].read()

    @staticmethod
    def writer(fits, ext, data):
        """Write data to FITS extension."""
        # only structured arrays can be written as FITS tables
        if data.dtype.names is None:
            msg = "data must be structured array"
            raise TypeError(msg)
        fits.write_table(data, extname=ext)

    @property
    def fits(self):
        """Return an opened FITS context manager."""
        # opened read-write so the same handle serves reads and writes;
        # clobber=False preserves existing file contents
        return fitsio.FITS(self.path, mode="rw", clobber=False)

    @property
    def toc(self):
        """Return a read-only view of the FITS table of contents."""
        return MappingProxyType(self._toc)

    def __init__(self, path, *, clobber=False):
        """Open or create the FITS file backing this mapping.

        If *path* does not exist, or *clobber* is true, a new FITS file
        with an empty primary HDU is created.  A ToC extension is created
        if not already present; otherwise the existing ToC is read.
        """
        self.path = Path(path)

        # FITS extension for table of contents
        self.ext = f"{self.tag.upper()}TOC"

        # if new or overwriting, create an empty FITS with primary HDU
        if not self.path.exists() or clobber:
            with fitsio.FITS(self.path, mode="rw", clobber=True) as fits:
                fits.write(None)

        # reopen FITS for writing data
        with self.fits as fits:
            # write a new ToC extension if FITS doesn't already contain one
            if self.ext not in fits:
                fits.create_table_hdu(
                    names=["EXT", *self.columns.keys()],
                    formats=["10A", *self.columns.values()],
                    extname=self.ext,
                )

                # get the dtype for ToC entries
                self.dtype = fits[self.ext].get_rec_dtype()[0]

                # empty ToC
                self._toc = TocDict()
            else:
                # read existing ToC from FITS
                toc = fits[self.ext].read()

                # store the dtype for ToC entries
                # FIX: was ``toc.dtype = toc.dtype``, a no-op that left
                # self.dtype unset and broke __setitem__ of new keys on
                # a reopened file
                self.dtype = toc.dtype

                # store the ToC as a mapping of key tuples to extension names
                self._toc = TocDict({tuple(key): str(ext) for ext, *key in toc})

        # set up a weakly-referenced cache for extension data
        self._cache = WeakValueDictionary()

    def __len__(self):
        return len(self._toc)

    def __iter__(self):
        return iter(self._toc)

    def __contains__(self, key):
        # keys are always tuples internally
        if not isinstance(key, tuple):
            key = (key,)
        return key in self._toc

    def __getitem__(self, key):
        """Return data for *key*, or a new instance for a selection."""
        ext = self._toc[key]

        # if a TocDict is returned, we have the result of a selection
        if isinstance(ext, TocDict):
            # make a new instance and copy attributes
            selected = object.__new__(self.__class__)
            selected.path = self.path
            # copy ToC extension name and entry dtype so the selection
            # remains fully functional (previously these were missing)
            selected.ext = self.ext
            selected.dtype = self.dtype
            # shared cache since both instances read the same file
            selected._cache = self._cache
            # the new toc contains the result of the selection
            selected._toc = ext
            return selected

        # a specific extension was requested, fetch data
        data = self._cache.get(ext)
        if data is None:
            with self.fits as fits:
                data = self.reader(fits, ext)
            self._cache[ext] = data
        return data

    def __setitem__(self, key, value):
        # keys are always tuples
        if not isinstance(key, tuple):
            key = (key,)

        # check if an extension with the given key already exists
        # otherwise, get the first free extension with the given tag
        if key in self._toc:
            ext = self._toc[key]
        else:
            extn = len(self._toc)
            exts = set(self._toc.values())
            # linear probe for an unused extension name
            while (ext := f"{self.tag.upper()}{extn}") in exts:
                extn += 1

        # write data using the class writer, and update ToC as necessary
        with self.fits as fits:
            self.writer(fits, ext, value)
            if key not in self._toc:
                tocentry = np.empty(1, dtype=self.dtype)
                tocentry[0] = (ext, *key)
                fits[self.ext].append(tocentry)
                self._toc[key] = ext

    def __delitem__(self, key):
        # fitsio does not support deletion of extensions
        msg = "deleting FITS extensions is not supported"
        raise NotImplementedError(msg)


class AlmFits(TocFits):
    """FITS-backed mapping for alms."""

    # data extensions are named ALM0, ALM1, ...; ToC extension is ALMTOC
    tag = "ALM"
    # ToC columns: field name and tomographic bin number
    columns = {"NAME": "10A", "BIN": "I"}
    # alms are complex-valued: stored/loaded as real+imag table columns
    reader = staticmethod(_read_complex)
    writer = staticmethod(_write_complex)


class ClsFits(TocFits):
    """FITS-backed mapping for cls."""

    # data extensions are named CL0, CL1, ...; ToC extension is CLTOC
    tag = "CL"
    # ToC columns: the two field names and the two bin numbers of each cl
    columns = {"NAME1": "10A", "NAME2": "10A", "BIN1": "I", "BIN2": "I"}
    reader = staticmethod(_read_twopoint)
    # partial (not staticmethod) binds the extra ``name`` argument;
    # partial objects are not descriptors, so no ``self`` is prepended
    writer = partial(_write_twopoint, name=tag)


class MmsFits(TocFits):
    """FITS-backed mapping for mixing matrices."""

    # data extensions are named MM0, MM1, ...; ToC extension is MMTOC
    tag = "MM"
    # ToC columns: mixing-matrix name and the two bin numbers
    columns = {"NAME": "10A", "BIN1": "I", "BIN2": "I"}
    reader = staticmethod(_read_twopoint)
    # partial (not staticmethod) binds the extra ``name`` argument;
    # partial objects are not descriptors, so no ``self`` is prepended
    writer = partial(_write_twopoint, name=tag)
41 changes: 32 additions & 9 deletions heracles/twopoint.py
Original file line number Diff line number Diff line change
@@ -41,7 +41,15 @@
logger = logging.getLogger(__name__)


def angular_power_spectra(alms, alms2=None, *, lmax=None, include=None, exclude=None):
def angular_power_spectra(
alms,
alms2=None,
*,
lmax=None,
include=None,
exclude=None,
out=None,
):
"""compute angular power spectra from a set of alms"""

logger.info(
@@ -55,32 +63,47 @@ def angular_power_spectra(alms, alms2=None, *, lmax=None, include=None, exclude=

# collect all alm combinations for computing cls
if alms2 is None:
alm_pairs = combinations_with_replacement(alms.items(), 2)
pairs = combinations_with_replacement(alms, 2)
alms2 = alms
else:
alm_pairs = product(alms.items(), alms2.items())
pairs = product(alms, alms2)

# keep track of the twopoint combinations we have seen here
twopoint_names = set()

# output tocdict, use given or empty
if out is None:
cls = TocDict()
else:
cls = out

# compute cls for all alm pairs
# do not compute duplicates
cls = TocDict()
for ((k1, i1), alm1), ((k2, i2), alm2) in alm_pairs:
for (k1, i1), (k2, i2) in pairs:
# skip duplicate cls in any order
if (k1, k2, i1, i2) in cls or (k2, k1, i2, i1) in cls:
continue

# get the two-point code in standard order
if (k1, k2) not in twopoint_names and (k2, k1) in twopoint_names:
i1, i2 = i2, i1
k1, k2 = k2, k1

# skip duplicate cls in any order
if (k1, k2, i1, i2) in cls or (k2, k1, i2, i1) in cls:
continue
swapped = True
else:
swapped = False

# check if cl is skipped by explicit include or exclude list
if not toc_match((k1, k2, i1, i2), include, exclude):
continue

logger.info("computing %s x %s cl for bins %s, %s", k1, k2, i1, i2)

# retrieve alms from keys; make sure swap is respected
# this is done only now because alms might lazy-load from file
alm1, alm2 = alms[k1, i1], alms2[k2, i2]
if swapped:
alm1, alm2 = alm2, alm1

# compute the raw cl from the alms
cl = hp.alm2cl(alm1, alm2, lmax_out=lmax)

112 changes: 112 additions & 0 deletions tests/test_io.py
Original file line number Diff line number Diff line change
@@ -375,3 +375,115 @@ def test_read_mask_extra(
extra_mask_name=mock_writemask_extra,
)
assert (mask == maps[:, ibin] * maps_extra[:]).all()


def test_tocfits(tmp_path):
    """End-to-end test of TocFits: create, write, read back, reopen."""
    import fitsio
    import numpy as np

    from heracles.io import TocFits

    # subclass with a custom tag and ToC columns for two integer key parts
    class TestFits(TocFits):
        tag = "test"
        columns = {"col1": "I", "col2": "J"}

    path = tmp_path / "test.fits"

    assert not path.exists()

    # creating with clobber=True must create the file on disk
    tocfits = TestFits(path, clobber=True)

    assert path.exists()

    # fresh file: primary HDU + empty ToC extension, with EXT column
    # prepended to the declared columns
    with fitsio.FITS(path) as fits:
        assert len(fits) == 2
        toc = fits["TESTTOC"].read()
        assert toc.dtype.names == ("EXT", "col1", "col2")
        assert len(toc) == 0

    # the in-memory mapping starts out empty as well
    assert len(tocfits) == 0
    assert list(tocfits) == []
    assert tocfits.toc == {}

    data12 = np.zeros(5, dtype=[("X", float), ("Y", int)])
    data22 = np.ones(5, dtype=[("X", float), ("Y", int)])

    # writing a new key must append a data extension and a ToC row
    tocfits[1, 2] = data12

    with fitsio.FITS(path) as fits:
        assert len(fits) == 3
        toc = fits["TESTTOC"].read()
        assert len(toc) == 1
        np.testing.assert_array_equal(fits["TEST0"].read(), data12)

    # the mapping reflects the new entry and reads back the same data
    assert len(tocfits) == 1
    assert list(tocfits) == [(1, 2)]
    assert tocfits.toc == {(1, 2): "TEST0"}
    np.testing.assert_array_equal(tocfits[1, 2], data12)

    # a second key goes into the next free extension (TEST1)
    tocfits[2, 2] = data22

    with fitsio.FITS(path) as fits:
        assert len(fits) == 4
        toc = fits["TESTTOC"].read()
        assert len(toc) == 2
        np.testing.assert_array_equal(fits["TEST0"].read(), data12)
        np.testing.assert_array_equal(fits["TEST1"].read(), data22)

    assert len(tocfits) == 2
    assert list(tocfits) == [(1, 2), (2, 2)]
    assert tocfits.toc == {(1, 2): "TEST0", (2, 2): "TEST1"}
    np.testing.assert_array_equal(tocfits[1, 2], data12)
    np.testing.assert_array_equal(tocfits[2, 2], data22)

    # deletion of FITS extensions is unsupported by fitsio
    with pytest.raises(NotImplementedError):
        del tocfits[1, 2]

    del tocfits

    # reopening without clobber must recover the existing ToC and data
    tocfits2 = TestFits(path, clobber=False)

    assert len(tocfits2) == 2
    assert list(tocfits2) == [(1, 2), (2, 2)]
    assert tocfits2.toc == {(1, 2): "TEST0", (2, 2): "TEST1"}
    np.testing.assert_array_equal(tocfits2[1, 2], data12)
    np.testing.assert_array_equal(tocfits2[2, 2], data22)


def test_tocfits_is_lazy(tmp_path):
    """TocFits must not read extension data until a value is accessed."""
    import fitsio

    from heracles.io import TocFits

    path = tmp_path / "bad.fits"

    # test keys(), values(), and items() are not eagerly reading data
    tocfits = TocFits(path, clobber=True)

    # manually enter some non-existent rows into the ToC; reading any of
    # these extensions from the FITS file would necessarily fail
    assert tocfits._toc == {}
    tocfits._toc[0,] = "BAD0"
    tocfits._toc[1,] = "BAD1"
    tocfits._toc[2,] = "BAD2"

    # these should not error, since the views are lazy
    tocfits.keys()
    tocfits.values()
    tocfits.items()

    # contains and iteration are lazy
    assert 0 in tocfits
    assert list(tocfits) == [(0,), (1,), (2,)]

    # subselection should work fine without touching extension data
    selected = tocfits[...]
    assert isinstance(selected, TocFits)
    assert len(selected) == 3

    # make sure nothing is in the FITS (primary HDU + ToC only)
    with fitsio.FITS(path, "r") as fits:
        assert len(fits) == 2

    # make sure there are errors when actualising the generators
    with pytest.raises(OSError):
        list(tocfits.values())