Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH(io): write nested tuple keys to FITS #128

Merged
merged 2 commits into from
Apr 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 30 additions & 15 deletions heracles/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@

import logging
import os
import re
from collections.abc import MutableMapping
from functools import partial
from pathlib import Path
from types import MappingProxyType
from typing import TYPE_CHECKING
from warnings import warn
from weakref import WeakValueDictionary

Expand All @@ -32,6 +34,9 @@

from .core import TocDict, toc_match

if TYPE_CHECKING:
from typing import TypeAlias

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -63,21 +68,33 @@
"bias": "additive bias of spectrum",
}

# type for valid keys
_KeyType: "TypeAlias" = "str | int | tuple[_KeyType, ...]"


def _extname_from_key(key):
def _extname_from_key(key: _KeyType) -> str:
"""
Return FITS extension name for a given key.
"""
if not isinstance(key, tuple):
key = (key,)
return ",".join(map(str, key))
if isinstance(key, tuple):
names = list(map(_extname_from_key, key))
c = ";" if any("," in name for name in names) else ","
return c.join(names)
return re.sub(r"\W+", "_", str(key))


def _key_from_extname(ext):
def _key_from_extname(extname: str) -> _KeyType:
"""
Return key for a given FITS extension name.
"""
return tuple(int(s) if s.isdigit() else s for s in ext.split(","))
keys = extname.split(";")
if len(keys) > 1:
return tuple(map(_key_from_extname, keys))
keys = keys[0].split(",")
if len(keys) > 1:
return tuple(map(_key_from_extname, keys))
key = keys[0]
return int(key) if key.isdigit() else key


def _iterfits(path, include=None, exclude=None):
Expand Down Expand Up @@ -525,15 +542,15 @@ def write_cov(filename, cov, clobber=False, workdir=".", include=None, exclude=N

# reopen FITS for writing data
with fitsio.FITS(path, mode="rw", clobber=False) as fits:
for (k1, k2), mat in cov.items():
for key, mat in cov.items():
# skip if not selected
if not toc_match((k1, k2), include=include, exclude=exclude):
if not toc_match(key, include=include, exclude=exclude):
continue

# the cl extension name
ext = _extname_from_key(k1 + k2)
logger.info("writing covariance matrix %s", key)

logger.info("writing %s x %s covariance matrix", k1, k2)
# the cov extension name
ext = _extname_from_key(key)

# write the covariance matrix as an image
fits.write_image(mat, extname=ext)
Expand Down Expand Up @@ -566,9 +583,7 @@ def read_cov(filename, workdir=".", *, include=None, exclude=None):

# iterate over valid HDUs in the file
for key, hdu in _iterfits(path, include=include, exclude=exclude):
k1, k2 = key[: len(key) // 2], key[len(key) // 2 :]

logger.info("reading %s x %s covariance matrix", k1, k2)
logger.info("reading covariance matrix %s", key)

# read the covariance matrix from the extension
mat = hdu.read()
Expand All @@ -577,7 +592,7 @@ def read_cov(filename, workdir=".", *, include=None, exclude=None):
mat.dtype = np.dtype(mat.dtype, metadata=_read_metadata(hdu))

# store in set
cov[k1, k2] = mat
cov[key] = mat

logger.info("done with %d covariance(s)", len(cov))

Expand Down
30 changes: 29 additions & 1 deletion tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,34 @@ def mock_vmap(mock_vmap_fields, nside, datadir):
return filename


def test_extname_from_key():
from heracles.io import _extname_from_key

assert _extname_from_key("a") == "a"
assert _extname_from_key(1) == "1"
assert _extname_from_key(("a",)) == "a"
assert _extname_from_key((1,)) == "1"
assert _extname_from_key(("a", 1)) == "a,1"
assert _extname_from_key(("a", "b", 1, 2)) == "a,b,1,2"
assert _extname_from_key((("a", 1), "b")) == "a,1;b"
assert _extname_from_key((("a", 1), ("b", 2))) == "a,1;b,2"

# test special chars
assert _extname_from_key("a,b,c") == "a_b_c"
assert _extname_from_key("!@#$%^&*()[]{};,.") == "_"


def test_key_from_extname():
from heracles.io import _key_from_extname

assert _key_from_extname("a") == "a"
assert _key_from_extname("1") == 1
assert _key_from_extname("a,1") == ("a", 1)
assert _key_from_extname("a,b,1,2") == ("a", "b", 1, 2)
assert _key_from_extname("a,1;b") == (("a", 1), "b")
assert _key_from_extname("a,1;b,2") == (("a", 1), ("b", 2))


def test_write_read_maps(rng, tmp_path):
import healpy as hp
import numpy as np
Expand Down Expand Up @@ -214,7 +242,7 @@ def test_write_read_cls(mock_cls, tmp_path):

cls = read_cls(filename, workdir=workdir)

assert cls.keys() == mock_cls.keys()
assert list(cls.keys()) == list(mock_cls.keys())
for key in mock_cls:
assert key in cls
cl = cls[key]
Expand Down
Loading