Skip to content

Commit

Permalink
make mapper a property of fields
Browse files Browse the repository at this point in the history
  • Loading branch information
ntessore committed Apr 10, 2024
1 parent 28b8a31 commit d36aa7f
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 80 deletions.
43 changes: 36 additions & 7 deletions heracles/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,15 @@ def __init_subclass__(cls, *, spin: int | None = None) -> None:
break
cls.__ncol = (ncol - nopt, ncol)

def __init__(self, *columns: str, mask: str | None = None) -> None:
def __init__(
self,
mapper: Mapper | None,
*columns: str,
mask: str | None = None,
) -> None:
"""Initialise the field."""
super().__init__()
self.__mapper = mapper
self.__columns = self._init_columns(*columns) if columns else None
self.__mask = mask

Expand All @@ -110,6 +116,20 @@ def _init_columns(cls, *columns: str) -> Columns:
raise ValueError(msg)
return columns + (None,) * (nmax - len(columns))

@property
def mapper(self) -> Mapper | None:
"""Return the mapper used by this field."""
return self.__mapper

@property
def mapper_or_error(self) -> Mapper:
"""Return the mapper used by this field, or raise a :class:`ValueError`
if not set."""
if self.__mapper is None:
msg = "no mapper for field"
raise ValueError(msg)
return self.__mapper

@property
def columns(self) -> Columns | None:
"""Return the catalogue columns used by this field."""
Expand Down Expand Up @@ -143,7 +163,6 @@ def mask(self) -> str | None:
async def __call__(
self,
catalog: Catalog,
mapper: Mapper,
*,
progress: ProgressTask | None = None,
) -> ArrayLike:
Expand Down Expand Up @@ -211,12 +230,14 @@ def nbar(self, nbar: float | None) -> None:
async def __call__(
self,
catalog: Catalog,
mapper: Mapper,
*,
progress: ProgressTask | None = None,
) -> ArrayLike:
"""Map the given catalogue."""

# get mapper
mapper = self.mapper_or_error

# get catalogue column definition
col = self.columns_or_error

Expand Down Expand Up @@ -297,12 +318,14 @@ class ScalarField(Field, spin=0):
async def __call__(
self,
catalog: Catalog,
mapper: Mapper,
*,
progress: ProgressTask | None = None,
) -> ArrayLike:
"""Map real values from catalogue to HEALPix map."""

# get mapper
mapper = self.mapper_or_error

# get the column definition of the catalogue
*col, wcol = self.columns_or_error

Expand Down Expand Up @@ -375,12 +398,14 @@ class ComplexField(Field, spin=0):
async def __call__(
self,
catalog: Catalog,
mapper: Mapper,
*,
progress: ProgressTask | None = None,
) -> ArrayLike:
"""Map complex values from catalogue to HEALPix map."""

# get mapper
mapper = self.mapper_or_error

# get the column definition of the catalogue
*col, wcol = self.columns_or_error

Expand Down Expand Up @@ -446,12 +471,14 @@ class Visibility(Field, spin=0):
async def __call__(
self,
catalog: Catalog,
mapper: Mapper,
*,
progress: ProgressTask | None = None,
) -> ArrayLike:
"""Create a visibility map from the given catalogue."""

# get mapper
mapper = self.mapper_or_error

# make sure that catalogue has a visibility map
vmap = catalog.visibility
if vmap is None:
Expand Down Expand Up @@ -487,12 +514,14 @@ class Weights(Field, spin=0):
async def __call__(
self,
catalog: Catalog,
mapper: Mapper,
*,
progress: ProgressTask | None = None,
) -> ArrayLike:
"""Map catalogue weights."""

# get mapper
mapper = self.mapper_or_error

# get the columns for this field
*col, wcol = self.columns_or_error

Expand Down
18 changes: 10 additions & 8 deletions heracles/maps/_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,11 @@
from heracles.fields import Field
from heracles.progress import Progress, ProgressTask

from ._mapper import Mapper


async def _map_progress(
key: tuple[Any, ...],
field: Field,
catalog: Catalog,
mapper: Mapper,
progress: Progress | None,
) -> NDArray:
"""
Expand All @@ -59,7 +56,7 @@ async def _map_progress(
else:
task = None

result = await field(catalog, mapper, progress=task)
result = await field(catalog, progress=task)

if progress is not None:
task.remove()
Expand All @@ -69,7 +66,6 @@ async def _map_progress(


def map_catalogs(
mapper: Mapper,
fields: Mapping[Any, Field],
catalogs: Mapping[Any, Catalog],
*,
Expand Down Expand Up @@ -116,7 +112,7 @@ def map_catalogs(
for key, field, catalog in items:
if toc_match(key, include, exclude):
keys.append(key)
coros.append(_map_progress(key, field, catalog, mapper, prog))
coros.append(_map_progress(key, field, catalog, prog))

# run all coroutines concurrently
try:
Expand All @@ -141,7 +137,7 @@ def map_catalogs(


def transform_maps(
mapper: Mapper,
fields: Mapping[Any, Field],
maps: Mapping[tuple[Any, Any], NDArray],
*,
out: MutableMapping[tuple[Any, Any], NDArray] | None = None,
Expand Down Expand Up @@ -175,7 +171,13 @@ def transform_maps(
total=None,
)

alms = mapper.transform(m)
try:
field = fields[k]
except KeyError:
msg = f"unknown field name: {k}"
raise ValueError(msg)

alms = field.mapper_or_error.transform(m)

if isinstance(alms, tuple):
out[f"{k}_E", i] = alms[0]
Expand Down
Loading

0 comments on commit d36aa7f

Please sign in to comment.