gh-207: simplified output format (#211)
ntessore authored Dec 17, 2024
1 parent 9fda271 commit 38903fd
Showing 9 changed files with 900 additions and 736 deletions.
55 changes: 35 additions & 20 deletions examples/discrete.ipynb

Large diffs are not rendered by default.

242 changes: 138 additions & 104 deletions examples/example.ipynb

Large diffs are not rendered by default.

26 changes: 12 additions & 14 deletions heracles/__init__.py
@@ -47,15 +47,13 @@
"Visibility",
"Weights",
# io
"read",
"read_vmap",
"read_alms",
"read_cls",
"read_maps",
"read_mms",
"write",
"write_alms",
"write_cls",
"write_maps",
"write_mms",
# mapper
"Mapper",
# mapping
@@ -64,13 +62,13 @@
# progress
"NoProgress",
"Progress",
# result
"Result",
"binned",
# twopoint
"angular_power_spectra",
"debias_cls",
"mixing_matrices",
"bin2pt",
"binned_cls",
"binned_mms",
]

try:
@@ -108,15 +106,13 @@
)

from .io import (
read,
read_vmap,
read_alms,
read_cls,
read_maps,
read_mms,
write,
write_alms,
write_cls,
write_maps,
write_mms,
)

from .mapper import (
@@ -133,11 +129,13 @@
Progress,
)

from .result import (
Result,
binned,
)

from .twopoint import (
angular_power_spectra,
debias_cls,
mixing_matrices,
bin2pt,
binned_cls,
binned_mms,
)
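
The heracles/__init__.py diff above replaces the per-type two-point I/O helpers (read_cls/write_cls, read_mms/write_mms) with a generic read/write pair and a Result type. A minimal sketch of the simplified interface, based only on the signatures visible in this commit; the file name, key tuple and array values are illustrative, and passing just the data array and ell to the Result constructor is an assumption:

import numpy as np
import heracles

# a toy angular power spectrum wrapped in a Result (values are illustrative)
ell = np.arange(10)
cl = heracles.Result(np.ones(10), ell=ell)

# write a mapping of results to a single FITS file, overwriting any existing file
heracles.write("results.fits", {("POS", "POS", 1, 1): cl}, clobber=True)

# read everything back into a plain dict keyed like the input
results = heracles.read("results.fits")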
254 changes: 112 additions & 142 deletions heracles/io.py
@@ -22,7 +22,6 @@
import os
import re
from collections.abc import MutableMapping, Sequence
from functools import partial
from pathlib import Path
from types import MappingProxyType
from typing import TYPE_CHECKING, Union
@@ -33,6 +32,7 @@
import numpy as np

from .core import TocDict, toc_match
from .result import Result

if TYPE_CHECKING:
from typing import TypeAlias
@@ -264,51 +264,94 @@ def _read_complex(hdu):
return arr


def _write_twopoint(fits, ext, key, arr, name):
"""convert two-point data (i.e. one L column) to structured array"""

arr = np.asanyarray(arr)
def _write_result(fits, ext, key, result):
"""
Write a result array to FITS.
"""

# get the data into structured array if not already
if arr.dtype.names is None:
n, *dims = arr.shape
data = arr
# keep ndarray subclasses or we would lose all Result attributes
result = np.asanyarray(result)

# get ell axis
axis = getattr(result, "axis", result.ndim - 1)

# get data & move ell axis to front
data = np.moveaxis(result, axis, 0)

# get ell values or create default
ell = getattr(result, "ell", None)
if ell is None:
ell = np.arange(data.shape[0])

# get lower bounds or create default
lower = getattr(result, "lower", None)
if lower is None:
lower = ell

# get upper array bounds or create default
upper = getattr(result, "upper", None)
if upper is None:
upper = np.append(ell[1:], ell[-1] + 1)

# get weight array or create default
weight = getattr(result, "weight", None)
if weight is None:
weight = np.ones(data.shape[0])

# write the result as columnar data
fits.write_table(
[
data,
ell,
lower,
upper,
weight,
],
names=[
"ARRAY",
"ELL",
"LOWER",
"UPPER",
"WEIGHT",
],
extname=ext,
header=[
dict(name="ELLAXIS", value=1, comment="number of angular axes"),
dict(name="ELLAXIS1", value=axis, comment="index of angular axis 1"),
],
)

dt = np.dtype(
[
("L", float),
(name, arr.dtype.str, dims) if dims else (name, arr.dtype.str),
("LMIN", float),
("LMAX", float),
("W", float),
],
metadata=dict(arr.dtype.metadata or {}),
)
# write the metadata
_write_metadata(fits[ext], result.dtype.metadata)

arr = np.empty(n, dt)
arr["L"] = np.arange(n)
arr[name] = data
arr["LMIN"] = arr["L"]
arr["LMAX"] = arr["L"] + 1
arr["W"] = 1

# write the twopoint data
fits.write_table(arr, extname=ext)
def _read_result(hdu):
"""
Read a result array from FITS.
"""

# write the key
_write_key(fits[ext], key)
# read columnar data from extension
data = hdu.read()
h = hdu.read_header()

# write the metadata
_write_metadata(fits[ext], arr.dtype.metadata)
# the angular axis
elldim = h["ELLAXIS"]
if elldim != 1:
raise NotImplementedError("multiple angular axes are not supported")
axis = tuple(h[f"ELLAXIS{i}"] for i in range(1, elldim + 1))

# get data array and move axis back to right position
result = np.moveaxis(data["ARRAY"], tuple(range(elldim)), axis)

def _read_twopoint(hdu):
"""read two-point data from FITS"""
# read data from extension
arr = hdu.read()
# read and attach metadata
arr.dtype = np.dtype(arr.dtype, metadata=_read_metadata(hdu))
return arr
# construct result array with ancillary arrays and metadata
return Result(
result,
axis=axis[0] if elldim == 1 else axis,
ell=data["ELL"],
lower=data["LOWER"],
upper=data["UPPER"],
weight=data["WEIGHT"],
).view(np.dtype(result.dtype, metadata=_read_metadata(hdu)))


def read_vmap(filename, nside=None, field=0, *, transform=False, lmax=None):
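
_write_result above stores each result as a single binary-table extension with columns ARRAY, ELL, LOWER, UPPER and WEIGHT, plus ELLAXIS/ELLAXIS1 header keys recording the number of angular axes and where the angular axis sits in ARRAY; _read_result reverses the transformation. A rough sketch of inspecting such an extension directly with fitsio (file name and extension index are hypothetical):

import fitsio

with fitsio.FITS("results.fits") as fits:
    hdu = fits[1]                          # first result extension
    header = hdu.read_header()
    axis = header["ELLAXIS1"]              # original position of the ell axis
    data = hdu["ARRAY"][:]                 # result data, ell axis moved to the front
    ell = hdu["ELL"][:]                    # effective ell values
    lower, upper = hdu["LOWER"][:], hdu["UPPER"][:]   # bin bounds
    weight = hdu["WEIGHT"][:]              # per-bin weights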
@@ -467,76 +510,16 @@ def read_alms(filename, workdir=".", *, include=None, exclude=None):
return alms


def write_cls(filename, cls, *, clobber=False, workdir=".", include=None, exclude=None):
"""write a set of cls to FITS file
If the output file exists, the new estimates will be appended, unless the
``clobber`` parameter is set to ``True``.
def write(path, results, *, clobber=False):
"""
Write a set of results to FITS file.
logger.info("writing %d cls to %s", len(cls), filename)

# full path to FITS file
path = os.path.join(workdir, filename)

# if new or overwriting, create an empty FITS with primary HDU
if not os.path.isfile(path) or clobber:
with fitsio.FITS(path, mode="rw", clobber=True) as fits:
fits.write(None)

# reopen FITS for writing data
with fitsio.FITS(path, mode="rw", clobber=False) as fits:
for key, cl in cls.items():
# skip if not selected
if not toc_match(key, include=include, exclude=exclude):
continue

logger.info("writing cl %s", key)

# extension name
ext = _get_next_extname(fits, "CL")

# write the data in structured format
_write_twopoint(fits, ext, key, cl, "CL")

logger.info("done with %d cls", len(cls))


def read_cls(filename, workdir=".", *, include=None, exclude=None):
"""read a set of cls from a FITS file"""

logger.info("reading cls from %s", filename)

# full path to FITS file
path = os.path.join(workdir, filename)

# the returned set of cls
cls = TocDict()

# iterate over valid HDUs in the file
for key, hdu in _iterfits(path, "CL", include=include, exclude=exclude):
logger.info("reading cl %s", key)
cls[key] = _read_twopoint(hdu)

logger.info("done with %d cls", len(cls))

# return the dictionary of cls
return cls


def write_mms(filename, mms, *, clobber=False, workdir=".", include=None, exclude=None):
"""write a set of mixing matrices to FITS file
If the output file exists, the new mixing matrices will be appended, unless
If the output file exists, the new results will be appended, unless
the ``clobber`` parameter is set to ``True``.
"""

logger.info("writing %d mm(s) to %s", len(mms), filename)

# full path to FITS file
path = os.path.join(workdir, filename)
logger.info("writing %d results to %s", len(results), path)

# if new or overwriting, create an empty FITS with primary HDU
if not os.path.isfile(path) or clobber:
@@ -545,42 +528,45 @@ def write_mms(filename, mms, *, clobber=False, workdir=".", include=None, exclud

# reopen FITS for writing data
with fitsio.FITS(path, mode="rw", clobber=False) as fits:
for key, mm in mms.items():
# skip if not selected
if not toc_match(key, include=include, exclude=exclude):
continue

logger.info("writing mm %s", key)
for key, result in results.items():
logger.info("writing result %s", key)

# extension name
ext = _get_next_extname(fits, "MM")
ext = _string_from_key(key)

# write the data in structured format
_write_twopoint(fits, ext, key, mm, "MM")

logger.info("done with %d mm(s)", len(mms))
_write_result(fits, ext, key, result)

logger.info("done with %d results", len(results))

def read_mms(filename, workdir=".", *, include=None, exclude=None):
"""read a set of mixing matrices from a FITS file"""

logger.info("reading mixing matrices from %s", filename)
def read(path):
"""
Read a set of results from a FITS file.
"""

# full path to FITS file
path = os.path.join(workdir, filename)
logger.info("reading results from %s", path)

# the returned set of mms
mms = TocDict()
# the returned set of results
results = {}

# iterate over valid HDUs in the file
for key, hdu in _iterfits(path, "MM", include=include, exclude=exclude):
logger.info("writing mm %s", key)
mms[key] = _read_twopoint(hdu)
# read all HDUs in file into dict keys
with fitsio.FITS(path) as fits:
for hdu in fits:
if not hdu.has_data():
continue
ext = hdu.get_extname()
if not ext:
continue
key = _key_from_string(ext)
if not key:
continue
logger.info("reading result %s", key)
results[key] = _read_result(hdu)

logger.info("done with %d mm(s)", len(mms))
logger.info("done with %d results", len(results))

# return the dictionary of mms
return mms
return results


def write_cov(filename, cov, clobber=False, workdir=".", include=None, exclude=None):
@@ -781,19 +767,3 @@ class AlmFits(TocFits):
tag = "ALM"
reader = staticmethod(_read_complex)
writer = staticmethod(_write_complex)


class ClsFits(TocFits):
"""FITS-backed mapping for cls."""

tag = "CL"
reader = staticmethod(_read_twopoint)
writer = partial(_write_twopoint, name="CL")


class MmsFits(TocFits):
"""FITS-backed mapping for mixing matrices."""

tag = "MM"
reader = staticmethod(_read_twopoint)
writer = partial(_write_twopoint, name="MM")
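
With ClsFits/MmsFits and the per-type readers and writers removed, existing call sites migrate to the generic functions. A hedged before/after sketch; the old signatures are taken from the removed code above, the file names are illustrative, and note that the new read has no workdir, include or exclude parameters, so any key filtering happens on the returned dict:

import heracles

# before this commit:
#   cls = read_cls("cls.fits", workdir=".", include=..., exclude=...)
#   write_mms("mms.fits", mms, clobber=True, workdir=".")

# after: one generic pair of functions for every kind of result
results = heracles.read("cls.fits")
heracles.write("combined.fits", results, clobber=True)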