Skip to content

Commit

Permalink
update CLI module
Browse files Browse the repository at this point in the history
  • Loading branch information
ntessore committed Apr 10, 2024
1 parent d2b85ed commit b6f01c8
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 72 deletions.
63 changes: 36 additions & 27 deletions heracles/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
if TYPE_CHECKING:
from numpy.typing import NDArray

from .fields import Field

# valid option keys
FIELD_TYPES = {
"positions": "heracles.fields:Positions",
Expand Down Expand Up @@ -100,7 +102,7 @@ def __init__(self) -> None:
# fully specify parent class
super().__init__(
defaults={
"kernel": "healpix",
"mapper": "healpix",
},
dict_type=dict,
allow_no_value=False,
Expand Down Expand Up @@ -157,6 +159,27 @@ def subsections(self, group: str) -> dict[str, str]:
return {s.rpartition(":")[-1].strip(): s for s in sections}


def mapper_from_config(config, section):
"""Construct a mapper instance from config."""

choices = {
"none": "none",
"healpix": "healpix",
}

mapper = config.getchoice(section, "mapper", choices)
if mapper == "none":
return None
if mapper == "healpix":
from .maps import Healpix

nside = config.getint(section, "nside")
lmax = config.getint(section, "lmax", fallback=None)
deconvolve = config.getint(section, "deconvolve", fallback=None)
return Healpix(nside, lmax, deconvolve=deconvolve)
return None


def field_from_config(config, section):
"""Construct a field instance from config."""

Expand All @@ -175,9 +198,10 @@ def field_from_config(config, section):
raise RuntimeError(msg) from None
else:
cls = _type
mapper = mapper_from_config(config, section)
columns = config.getlist(section, "columns", fallback=())
mask = config.get(section, "mask", fallback=None)
return cls(*columns, mask=mask)
return cls(mapper, *columns, mask=mask)


def fields_from_config(config):
Expand All @@ -188,16 +212,6 @@ def fields_from_config(config):
}


def mappers_from_config(config):
"""Construct all mapper instances from config."""
from .maps import mapper_from_dict

sections = config.subsections("fields")
return {
name: mapper_from_dict(config[section]) for name, section in sections.items()
}


def catalog_from_config(config, section, label=None, *, out=None):
"""Construct a catalogue instance from config."""

Expand Down Expand Up @@ -269,12 +283,6 @@ def catalogs_from_config(config):
return catalogs


def lmax_from_config(config):
"""Construct a dictionary with LMAX values for all fields."""
sections = config.subsections("fields")
return {name: config.getint(section, "lmax") for name, section in sections.items()}


def bins_from_config(config, section):
"""Construct angular bins from config."""

Expand Down Expand Up @@ -392,6 +400,7 @@ def configloader(path: Paths) -> ConfigParser:


def map_all_selections(
fields: Mapping[str, Field],
config: ConfigParser,
logger: logging.Logger,
progress: bool,
Expand All @@ -400,10 +409,8 @@ def map_all_selections(

from .maps import map_catalogs

# load catalogues, mappers, and fields to process
# load catalogues to process
catalogs = catalogs_from_config(config)
mappers = mappers_from_config(config)
fields = fields_from_config(config)

logger.info("fields %s", ", ".join(map(repr, fields)))

Expand All @@ -417,7 +424,6 @@ def map_all_selections(

# maps for single catalogue
yield map_catalogs(
mappers,
fields,
{key: catalog},
parallel=True, # process everything at this level in one go
Expand Down Expand Up @@ -455,9 +461,12 @@ def maps(
logger.info("reading configuration from %s", files)
config = loader(files)

# construct fields for mapping
fields = fields_from_config(config)

# iterator over the individual maps
# this generates maps on the fly
itermaps = map_all_selections(config, logger, progress)
itermaps = map_all_selections(fields, config, logger, progress)

# output goes into a FITS-backed tocdict so we don't fill memory up
out = MapFits(path, clobber=True)
Expand Down Expand Up @@ -500,16 +509,16 @@ def alms(
if healpix_datapath is not None:
Healpix.DATAPATH = healpix_datapath

# load the individual lmax values for each field into a dictionary
lmax = lmax_from_config(config)
# construct fields to get mappers for transform
fields = fields_from_config(config)

# process either catalogues or maps
# everything is loaded via iterators to keep memory use low
itermaps: Iterator
if maps:
itermaps = load_all_maps(maps, logger)
else:
itermaps = map_all_selections(config, logger, progress)
itermaps = map_all_selections(fields, config, logger, progress)

# output goes into a FITS-backed tocdict so we don't fill up memory
logger.info("writing alms to %s", path)
Expand All @@ -519,8 +528,8 @@ def alms(
for maps in itermaps:
logger.info("transforming %d maps", len(maps))
transform_maps(
fields,
maps,
lmax=lmax,
progress=progress,
out=out,
)
Expand Down
48 changes: 3 additions & 45 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ def test_field_from_config():
}

config = ConfigParser()
config[config.default_section]["mapper"] = "none"
config.read_dict(
{
"a": {
Expand All @@ -135,9 +136,9 @@ def test_field_from_config():
with pytest.raises(RuntimeError, match="Internal error"):
field_from_config(config, "c")

mock.assert_called_once_with("COL1", "-COL2", mask="x")
mock.assert_called_once_with(None, "COL1", "-COL2", mask="x")
assert mock.return_value is a
other_mock.assert_called_once_with(mask=None)
other_mock.assert_called_once_with(None, mask=None)
assert other_mock.return_value is b


Expand Down Expand Up @@ -168,33 +169,6 @@ def test_fields_from_config(mock):
]


@patch("heracles.maps.mapper_from_dict")
def test_mappers_from_config(mock):
from heracles.cli import ConfigParser, mappers_from_config

config = ConfigParser()
config.read_dict(
{
"fields:a": {},
"fields:b": {},
"fields:c": {},
},
)

m = mappers_from_config(config)

assert m == {
"a": mock.return_value,
"b": mock.return_value,
"c": mock.return_value,
}
assert mock.call_args_list == [
((config["fields:a"],),),
((config["fields:b"],),),
((config["fields:c"],),),
]


@patch("heracles.io.read_vmap")
def test_catalog_from_config(mock):
from heracles.cli import ConfigParser, catalog_from_config
Expand Down Expand Up @@ -299,22 +273,6 @@ def test_catalogs_from_config(mock):
]


def test_lmax_from_config():
from heracles.cli import ConfigParser, lmax_from_config

config = ConfigParser()
config.read_dict(
{
"defaults": {"lmax": 30},
"fields:a": {"lmax": 10},
"fields:b": {"lmax": 20},
"fields:c": {}, # should use defaults
},
)

assert lmax_from_config(config) == {"a": 10, "b": 20, "c": 30}


def test_bins_from_config():
from heracles.cli import ConfigParser, bins_from_config

Expand Down

0 comments on commit b6f01c8

Please sign in to comment.