20
20
21
21
import logging
22
22
import os
23
+ import re
23
24
from collections .abc import MutableMapping
24
25
from functools import partial
25
26
from pathlib import Path
26
27
from types import MappingProxyType
28
+ from typing import TYPE_CHECKING
27
29
from warnings import warn
28
30
from weakref import WeakValueDictionary
29
31
32
34
33
35
from .core import TocDict , toc_match
34
36
37
+ if TYPE_CHECKING :
38
+ from typing import TypeAlias
39
+
35
40
logger = logging .getLogger (__name__ )
36
41
37
42
63
68
"bias" : "additive bias of spectrum" ,
64
69
}
65
70
71
+ # type for valid keys
72
+ _KeyType : "TypeAlias" = "str | int | tuple[_KeyType, ...]"
73
+
66
74
67
- def _extname_from_key (key ) :
75
+ def _extname_from_key (key : _KeyType ) -> str :
68
76
"""
69
77
Return FITS extension name for a given key.
70
78
"""
71
- if not isinstance (key , tuple ):
72
- key = (key ,)
73
- return "," .join (map (str , key ))
79
+ if isinstance (key , tuple ):
80
+ names = list (map (_extname_from_key , key ))
81
+ c = ";" if any ("," in name for name in names ) else ","
82
+ return c .join (names )
83
+ return re .sub (r"\W+" , "_" , str (key ))
74
84
75
85
76
- def _key_from_extname (ext ) :
86
+ def _key_from_extname (extname : str ) -> _KeyType :
77
87
"""
78
88
Return key for a given FITS extension name.
79
89
"""
80
- return tuple (int (s ) if s .isdigit () else s for s in ext .split ("," ))
90
+ keys = extname .split (";" )
91
+ if len (keys ) > 1 :
92
+ return tuple (map (_key_from_extname , keys ))
93
+ keys = keys [0 ].split ("," )
94
+ if len (keys ) > 1 :
95
+ return tuple (map (_key_from_extname , keys ))
96
+ key = keys [0 ]
97
+ return int (key ) if key .isdigit () else key
81
98
82
99
83
100
def _iterfits (path , include = None , exclude = None ):
@@ -525,15 +542,15 @@ def write_cov(filename, cov, clobber=False, workdir=".", include=None, exclude=N
525
542
526
543
# reopen FITS for writing data
527
544
with fitsio .FITS (path , mode = "rw" , clobber = False ) as fits :
528
- for ( k1 , k2 ) , mat in cov .items ():
545
+ for key , mat in cov .items ():
529
546
# skip if not selected
530
- if not toc_match (( k1 , k2 ) , include = include , exclude = exclude ):
547
+ if not toc_match (key , include = include , exclude = exclude ):
531
548
continue
532
549
533
- # the cl extension name
534
- ext = _extname_from_key (k1 + k2 )
550
+ logger .info ("writing covariance matrix %s" , key )
535
551
536
- logger .info ("writing %s x %s covariance matrix" , k1 , k2 )
552
+ # the cov extension name
553
+ ext = _extname_from_key (key )
537
554
538
555
# write the covariance matrix as an image
539
556
fits .write_image (mat , extname = ext )
@@ -566,9 +583,7 @@ def read_cov(filename, workdir=".", *, include=None, exclude=None):
566
583
567
584
# iterate over valid HDUs in the file
568
585
for key , hdu in _iterfits (path , include = include , exclude = exclude ):
569
- k1 , k2 = key [: len (key ) // 2 ], key [len (key ) // 2 :]
570
-
571
- logger .info ("reading %s x %s covariance matrix" , k1 , k2 )
586
+ logger .info ("reading covariance matrix %s" , key )
572
587
573
588
# read the covariance matrix from the extension
574
589
mat = hdu .read ()
@@ -577,7 +592,7 @@ def read_cov(filename, workdir=".", *, include=None, exclude=None):
577
592
mat .dtype = np .dtype (mat .dtype , metadata = _read_metadata (hdu ))
578
593
579
594
# store in set
580
- cov [k1 , k2 ] = mat
595
+ cov [key ] = mat
581
596
582
597
logger .info ("done with %d covariance(s)" , len (cov ))
583
598
0 commit comments