Skip to content

Commit ddd4194

Browse files
authored
ENH(io): write nested tuple keys to FITS (#128)
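Keys written to FITS files can now be tuples of tuples (but not nested any deeper). In the resulting extension name, the items of an inner tuple are separated by commas, and the inner tuples themselves are separated by semicolons; for example, the key (("a", 1), ("b", 2)) is written as "a,1;b,2". Closes: #127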
1 parent 9160f6b commit ddd4194

File tree

2 files changed

+59
-16
lines changed

2 files changed

+59
-16
lines changed

heracles/io.py

+30-15
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,12 @@
2020

2121
import logging
2222
import os
23+
import re
2324
from collections.abc import MutableMapping
2425
from functools import partial
2526
from pathlib import Path
2627
from types import MappingProxyType
28+
from typing import TYPE_CHECKING
2729
from warnings import warn
2830
from weakref import WeakValueDictionary
2931

@@ -32,6 +34,9 @@
3234

3335
from .core import TocDict, toc_match
3436

37+
if TYPE_CHECKING:
38+
from typing import TypeAlias
39+
3540
logger = logging.getLogger(__name__)
3641

3742

@@ -63,21 +68,33 @@
6368
"bias": "additive bias of spectrum",
6469
}
6570

71+
# type for valid keys
72+
_KeyType: "TypeAlias" = "str | int | tuple[_KeyType, ...]"
73+
6674

67-
def _extname_from_key(key):
75+
def _extname_from_key(key: _KeyType) -> str:
6876
"""
6977
Return FITS extension name for a given key.
7078
"""
71-
if not isinstance(key, tuple):
72-
key = (key,)
73-
return ",".join(map(str, key))
79+
if isinstance(key, tuple):
80+
names = list(map(_extname_from_key, key))
81+
c = ";" if any("," in name for name in names) else ","
82+
return c.join(names)
83+
return re.sub(r"\W+", "_", str(key))
7484

7585

76-
def _key_from_extname(ext):
86+
def _key_from_extname(extname: str) -> _KeyType:
7787
"""
7888
Return key for a given FITS extension name.
7989
"""
80-
return tuple(int(s) if s.isdigit() else s for s in ext.split(","))
90+
keys = extname.split(";")
91+
if len(keys) > 1:
92+
return tuple(map(_key_from_extname, keys))
93+
keys = keys[0].split(",")
94+
if len(keys) > 1:
95+
return tuple(map(_key_from_extname, keys))
96+
key = keys[0]
97+
return int(key) if key.isdigit() else key
8198

8299

83100
def _iterfits(path, include=None, exclude=None):
@@ -525,15 +542,15 @@ def write_cov(filename, cov, clobber=False, workdir=".", include=None, exclude=N
525542

526543
# reopen FITS for writing data
527544
with fitsio.FITS(path, mode="rw", clobber=False) as fits:
528-
for (k1, k2), mat in cov.items():
545+
for key, mat in cov.items():
529546
# skip if not selected
530-
if not toc_match((k1, k2), include=include, exclude=exclude):
547+
if not toc_match(key, include=include, exclude=exclude):
531548
continue
532549

533-
# the cl extension name
534-
ext = _extname_from_key(k1 + k2)
550+
logger.info("writing covariance matrix %s", key)
535551

536-
logger.info("writing %s x %s covariance matrix", k1, k2)
552+
# the cov extension name
553+
ext = _extname_from_key(key)
537554

538555
# write the covariance matrix as an image
539556
fits.write_image(mat, extname=ext)
@@ -566,9 +583,7 @@ def read_cov(filename, workdir=".", *, include=None, exclude=None):
566583

567584
# iterate over valid HDUs in the file
568585
for key, hdu in _iterfits(path, include=include, exclude=exclude):
569-
k1, k2 = key[: len(key) // 2], key[len(key) // 2 :]
570-
571-
logger.info("reading %s x %s covariance matrix", k1, k2)
586+
logger.info("reading covariance matrix %s", key)
572587

573588
# read the covariance matrix from the extension
574589
mat = hdu.read()
@@ -577,7 +592,7 @@ def read_cov(filename, workdir=".", *, include=None, exclude=None):
577592
mat.dtype = np.dtype(mat.dtype, metadata=_read_metadata(hdu))
578593

579594
# store in set
580-
cov[k1, k2] = mat
595+
cov[key] = mat
581596

582597
logger.info("done with %d covariance(s)", len(cov))
583598

tests/test_io.py

+29-1
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,34 @@ def mock_vmap(mock_vmap_fields, nside, datadir):
148148
return filename
149149

150150

151+
def test_extname_from_key():
152+
from heracles.io import _extname_from_key
153+
154+
assert _extname_from_key("a") == "a"
155+
assert _extname_from_key(1) == "1"
156+
assert _extname_from_key(("a",)) == "a"
157+
assert _extname_from_key((1,)) == "1"
158+
assert _extname_from_key(("a", 1)) == "a,1"
159+
assert _extname_from_key(("a", "b", 1, 2)) == "a,b,1,2"
160+
assert _extname_from_key((("a", 1), "b")) == "a,1;b"
161+
assert _extname_from_key((("a", 1), ("b", 2))) == "a,1;b,2"
162+
163+
# test special chars
164+
assert _extname_from_key("a,b,c") == "a_b_c"
165+
assert _extname_from_key("!@#$%^&*()[]{};,.") == "_"
166+
167+
168+
def test_key_from_extname():
169+
from heracles.io import _key_from_extname
170+
171+
assert _key_from_extname("a") == "a"
172+
assert _key_from_extname("1") == 1
173+
assert _key_from_extname("a,1") == ("a", 1)
174+
assert _key_from_extname("a,b,1,2") == ("a", "b", 1, 2)
175+
assert _key_from_extname("a,1;b") == (("a", 1), "b")
176+
assert _key_from_extname("a,1;b,2") == (("a", 1), ("b", 2))
177+
178+
151179
def test_write_read_maps(rng, tmp_path):
152180
import healpy as hp
153181
import numpy as np
@@ -214,7 +242,7 @@ def test_write_read_cls(mock_cls, tmp_path):
214242

215243
cls = read_cls(filename, workdir=workdir)
216244

217-
assert cls.keys() == mock_cls.keys()
245+
assert list(cls.keys()) == list(mock_cls.keys())
218246
for key in mock_cls:
219247
assert key in cls
220248
cl = cls[key]

0 commit comments

Comments
 (0)