Skip to content

Commit

Permalink
improved mapper interface
Browse files Browse the repository at this point in the history
  • Loading branch information
ntessore committed Apr 10, 2024
1 parent cfd2109 commit 214596e
Show file tree
Hide file tree
Showing 9 changed files with 147 additions and 379 deletions.
37 changes: 1 addition & 36 deletions heracles/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from collections import UserDict
from collections.abc import Mapping, Sequence
from typing import Any, Callable, TypeVar
from typing import TypeVar

import numpy as np

Expand Down Expand Up @@ -54,33 +54,6 @@ def toc_filter(obj, include=None, exclude=None):
raise TypeError(msg)


def multi_value_getter(obj: T | Mapping[Any, T]) -> Callable[[Any], T]:
"""Return a getter for values or mappings."""
if isinstance(obj, Mapping):

def getter(key: Any) -> T:
if isinstance(key, Sequence):
t = tuple(key)
else:
t = (key,)
while t:
if t in obj:
return obj[t]
if len(t) == 1 and t[0] in obj:
return obj[t[0]]
t = t[:-1]
if t in obj:
return obj[t]
raise KeyError(key)

else:

def getter(key: Any) -> T:
return obj

return getter


# subclassing UserDict here since that returns the correct type from methods
# such as __copy__(), __or__(), etc.
class TocDict(UserDict):
Expand Down Expand Up @@ -143,11 +116,3 @@ def update_metadata(array, *sources, **metadata):
raise ValueError(msg)
# set the new dtype in array
array.dtype = dt


def items_with_suffix(d: Mapping[str, Any], suffix: str) -> Mapping[str, Any]:
"""
Return items from *d* where keys end in *suffix*. Returns a mapping
where *suffix* is removed from keys.
"""
return {k.removesuffix(suffix): v for k, v in d.items() if k.endswith(suffix)}
42 changes: 18 additions & 24 deletions heracles/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from abc import ABCMeta, abstractmethod
from functools import partial
from itertools import combinations_with_replacement, product
from types import MappingProxyType
from typing import TYPE_CHECKING

import coroutines
Expand All @@ -34,7 +33,7 @@

if TYPE_CHECKING:
from collections.abc import AsyncIterable, Mapping, Sequence
from typing import Any, TypeGuard
from typing import TypeGuard

from numpy.typing import ArrayLike

Expand Down Expand Up @@ -89,9 +88,6 @@ def __init__(self, *columns: str, mask: str | None = None) -> None:
super().__init__()
self.__columns = self._init_columns(*columns) if columns else None
self.__mask = mask
self._metadata: dict[str, Any] = {}
if (spin := self.__spin) is not None:
self._metadata["spin"] = spin

@classmethod
def _init_columns(cls, *columns: str) -> Columns:
Expand Down Expand Up @@ -128,11 +124,6 @@ def columns_or_error(self) -> Columns:
raise ValueError(msg)
return self.__columns

@property
def metadata(self) -> Mapping[str, Any]:
"""Return the static metadata for this field."""
return MappingProxyType(self._metadata)

@property
def spin(self) -> int:
"""Spin weight of field."""
Expand Down Expand Up @@ -230,7 +221,7 @@ async def __call__(
col = self.columns_or_error

# position map
pos = np.zeros(mapper.size, mapper.dtype)
pos = mapper.create(spin=self.spin)

# keep track of the total number of galaxies
ngal = 0
Expand Down Expand Up @@ -292,7 +283,7 @@ async def __call__(
bias = ngal / (4 * np.pi) * mapper.area**2 / nbar**2

# set metadata of array
update_metadata(pos, self, catalog, mapper, nbar=nbar, bias=bias)
update_metadata(pos, catalog, nbar=nbar, bias=bias)

# return the position map
return pos
Expand All @@ -316,7 +307,7 @@ async def __call__(
*col, wcol = self.columns_or_error

# scalar field map
val = np.zeros(mapper.size, mapper.dtype)
val = mapper.create(spin=self.spin)

# total weighted variance from online algorithm
ngal = 0
Expand Down Expand Up @@ -365,7 +356,7 @@ async def __call__(
bias = 4 * np.pi * vbar**2 * (var / wmean**2) / ngal

# set metadata of array
update_metadata(val, self, catalog, mapper, wbar=wbar, bias=bias)
update_metadata(val, catalog, wbar=wbar, bias=bias)

# return the value map
return val
Expand Down Expand Up @@ -394,7 +385,7 @@ async def __call__(
*col, wcol = self.columns_or_error

# complex map with real and imaginary part
val = np.zeros((2, mapper.size), mapper.dtype)
val = mapper.create(2, spin=self.spin)

# total weighted variance from online algorithm
ngal = 0
Expand Down Expand Up @@ -443,7 +434,7 @@ async def __call__(
bias = 2 * np.pi * vbar**2 * (var / wmean**2) / ngal

# set metadata of array
update_metadata(val, self, catalog, mapper, wbar=wbar, bias=bias)
update_metadata(val, catalog, wbar=wbar, bias=bias)

# return the shear map
return val
Expand All @@ -467,22 +458,25 @@ async def __call__(
msg = "no visibility map in catalog"
raise ValueError(msg)

# create new visibility map
out = mapper.create(spin=self.spin)

# warn if visibility is changing resolution
if vmap.size != mapper.size:
if vmap.size != out.size:
import healpy as hp

warnings.warn(
f"changing NSIDE of visibility map "
f"from {hp.get_nside(vmap)} to {mapper.nside}",
)
vmap = hp.ud_grade(vmap, mapper.nside)
out[:] = hp.ud_grade(vmap, mapper.nside)
else:
# make a copy for updates to metadata
vmap = np.copy(vmap)
# copy pixel values
out[:] = vmap

update_metadata(vmap, self, catalog, mapper)
update_metadata(out, catalog)

return vmap
return out


class Weights(Field, spin=0):
Expand All @@ -503,7 +497,7 @@ async def __call__(
*col, wcol = self.columns_or_error

# weight map
wht = np.zeros(mapper.size, mapper.dtype)
wht = mapper.create(spin=self.spin)

# total weighted variance from online algorithm
ngal = 0
Expand Down Expand Up @@ -550,7 +544,7 @@ async def __call__(
bias = 4 * np.pi * vbar**2 * (w2mean / wmean**2) / ngal

# set metadata of array
update_metadata(wht, self, catalog, mapper, wbar=wbar, bias=bias)
update_metadata(wht, catalog, wbar=wbar, bias=bias)

# return the weight map
return wht
Expand Down
4 changes: 1 addition & 3 deletions heracles/maps/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,10 @@
__all__ = (
"Healpix",
"Mapper",
"get_kernels",
"map_catalogs",
"mapper_from_dict",
"transform_maps",
)

from ._healpix import Healpix
from ._mapper import Mapper, get_kernels, mapper_from_dict
from ._mapper import Mapper
from ._mapping import map_catalogs, transform_maps
Loading

0 comments on commit 214596e

Please sign in to comment.