diff --git a/heracles/fields.py b/heracles/fields.py index 0d7d915..ceb09be 100644 --- a/heracles/fields.py +++ b/heracles/fields.py @@ -83,9 +83,15 @@ def __init_subclass__(cls, *, spin: int | None = None) -> None: break cls.__ncol = (ncol - nopt, ncol) - def __init__(self, *columns: str, mask: str | None = None) -> None: + def __init__( + self, + mapper: Mapper | None, + *columns: str, + mask: str | None = None, + ) -> None: """Initialise the field.""" super().__init__() + self.__mapper = mapper self.__columns = self._init_columns(*columns) if columns else None self.__mask = mask @@ -110,6 +116,20 @@ def _init_columns(cls, *columns: str) -> Columns: raise ValueError(msg) return columns + (None,) * (nmax - len(columns)) + @property + def mapper(self) -> Mapper | None: + """Return the mapper used by this field.""" + return self.__mapper + + @property + def mapper_or_error(self) -> Mapper: + """Return the mapper used by this field, or raise a :class:`ValueError` + if not set.""" + if self.__mapper is None: + msg = "no mapper for field" + raise ValueError(msg) + return self.__mapper + @property def columns(self) -> Columns | None: """Return the catalogue columns used by this field.""" @@ -143,7 +163,6 @@ def mask(self) -> str | None: async def __call__( self, catalog: Catalog, - mapper: Mapper, *, progress: ProgressTask | None = None, ) -> ArrayLike: @@ -211,12 +230,14 @@ def nbar(self, nbar: float | None) -> None: async def __call__( self, catalog: Catalog, - mapper: Mapper, *, progress: ProgressTask | None = None, ) -> ArrayLike: """Map the given catalogue.""" + # get mapper + mapper = self.mapper_or_error + # get catalogue column definition col = self.columns_or_error @@ -297,12 +318,14 @@ class ScalarField(Field, spin=0): async def __call__( self, catalog: Catalog, - mapper: Mapper, *, progress: ProgressTask | None = None, ) -> ArrayLike: """Map real values from catalogue to HEALPix map.""" + # get mapper + mapper = self.mapper_or_error + # get the column definition of the catalogue *col, wcol = self.columns_or_error @@ -375,12 +398,14 @@ class ComplexField(Field, spin=0): async def __call__( self, catalog: Catalog, - mapper: Mapper, *, progress: ProgressTask | None = None, ) -> ArrayLike: """Map complex values from catalogue to HEALPix map.""" + # get mapper + mapper = self.mapper_or_error + # get the column definition of the catalogue *col, wcol = self.columns_or_error @@ -446,12 +471,14 @@ class Visibility(Field, spin=0): async def __call__( self, catalog: Catalog, - mapper: Mapper, *, progress: ProgressTask | None = None, ) -> ArrayLike: """Create a visibility map from the given catalogue.""" + # get mapper + mapper = self.mapper_or_error + # make sure that catalogue has a visibility map vmap = catalog.visibility if vmap is None: @@ -487,12 +514,14 @@ class Weights(Field, spin=0): async def __call__( self, catalog: Catalog, - mapper: Mapper, *, progress: ProgressTask | None = None, ) -> ArrayLike: """Map catalogue weights.""" + # get mapper + mapper = self.mapper_or_error + # get the columns for this field *col, wcol = self.columns_or_error diff --git a/heracles/maps/_mapping.py b/heracles/maps/_mapping.py index ed90fe2..018be45 100644 --- a/heracles/maps/_mapping.py +++ b/heracles/maps/_mapping.py @@ -38,14 +38,11 @@ from heracles.fields import Field from heracles.progress import Progress, ProgressTask - from ._mapper import Mapper - async def _map_progress( key: tuple[Any, ...], field: Field, catalog: Catalog, - mapper: Mapper, progress: Progress | None, ) -> NDArray: """ @@ -59,7 +56,7 @@ async def _map_progress( else: task = None - result = await field(catalog, mapper, progress=task) + result = await field(catalog, progress=task) if progress is not None: task.remove() @@ -69,7 +66,6 @@ async def _map_progress( def map_catalogs( - mapper: Mapper, fields: Mapping[Any, Field], catalogs: Mapping[Any, Catalog], *, @@ -116,7 +112,7 @@ def map_catalogs( for key, field, catalog in items: if toc_match(key, include, exclude): keys.append(key) - coros.append(_map_progress(key, field, catalog, mapper, prog)) + coros.append(_map_progress(key, field, catalog, prog)) # run all coroutines concurrently try: @@ -141,7 +137,7 @@ def map_catalogs( def transform_maps( - mapper: Mapper, + fields: Mapping[Any, Field], maps: Mapping[tuple[Any, Any], NDArray], *, out: MutableMapping[tuple[Any, Any], NDArray] | None = None, @@ -175,7 +171,13 @@ def transform_maps( total=None, ) - alms = mapper.transform(m) + try: + field = fields[k] + except KeyError: + msg = f"unknown field name: {k}" + raise ValueError(msg) + + alms = field.mapper_or_error.transform(m) if isinstance(alms, tuple): out[f"{k}_E", i] = alms[0] diff --git a/tests/test_fields.py b/tests/test_fields.py index c54f08a..6123f23 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -77,6 +77,8 @@ def catalog(page): def test_field_abc(): + from unittest.mock import Mock + from heracles.fields import Columns, Field with pytest.raises(TypeError): @@ -89,7 +91,7 @@ def _init_columns(self, *columns: str) -> Columns: async def __call__(self): pass - f = SpinLessField() + f = SpinLessField(None) with pytest.raises(ValueError, match="undefined spin weight"): f.spin @@ -100,24 +102,31 @@ class TestField(Field, spin=0): async def __call__(self): pass - f = TestField() + f = TestField(None) + assert f.mapper is None assert f.columns is None assert f.spin == 0 + with pytest.raises(ValueError): + f.mapper_or_error + with pytest.raises(ValueError): f.columns_or_error + mapper = Mock() + with pytest.raises(ValueError, match="accepts 2 to 3 columns"): - TestField("lon") + TestField(mapper, "lon") - f = TestField("lon", "lat", mask="W") + f = TestField(mapper, "lon", "lat", mask="W") + assert f.mapper is mapper assert f.columns == ("lon", "lat", None) assert f.mask == "W" -def test_visibility(mapper, vmap): +def test_visibility(nside, vmap): from contextlib import nullcontext from unittest.mock import Mock @@ -125,7 +134,6 @@ def test_visibility(mapper, vmap): from heracles.maps import Healpix fsky = vmap.mean() - nside = mapper.nside for nside_out in [nside // 2, nside, nside * 2]: catalog = Mock() @@ -134,10 +142,10 @@ def test_visibility(mapper, vmap): mapper_out = Healpix(nside_out) - f = Visibility() + f = Visibility(mapper_out) with pytest.warns(UserWarning) if nside != nside_out else nullcontext(): - result = coroutines.run(f(catalog, mapper_out)) + result = coroutines.run(f(catalog)) assert result is not vmap @@ -149,15 +157,16 @@ def test_visibility(mapper, vmap): "kernel": "healpix", "nside": mapper_out.nside, "lmax": mapper_out.lmax, + "deconv": mapper_out.deconvolve, } assert np.isclose(result.mean(), fsky) # test missing visibility map catalog = Mock() catalog.visibility = None - f = Visibility() + f = Visibility(mapper) with pytest.raises(ValueError, match="no visibility"): - coroutines.run(f(catalog, mapper)) + coroutines.run(f(catalog)) def test_positions(mapper, catalog, vmap): @@ -169,7 +178,7 @@ def test_positions(mapper, catalog, vmap): # normal mode: compute overdensity maps with metadata - f = Positions("ra", "dec") + f = Positions(mapper, "ra", "dec") # test some default settings assert f.spin == 0 @@ -177,7 +186,7 @@ def test_positions(mapper, catalog, vmap): assert f.nbar is None # create map - m = coroutines.run(f(catalog, mapper)) + m = coroutines.run(f(catalog)) nbar = 4.0 assert m.shape == (npix,) @@ -189,14 +198,15 @@ def test_positions(mapper, catalog, vmap): "kernel": "healpix", "nside": mapper.nside, "lmax": mapper.lmax, + "deconv": mapper.deconvolve, "bias": pytest.approx(bias / nbar**2), } np.testing.assert_array_equal(m, 0) # compute number count map - f = Positions("ra", "dec", overdensity=False) - m = coroutines.run(f(catalog, mapper)) + f = Positions(mapper, "ra", "dec", overdensity=False) + m = coroutines.run(f(catalog)) assert m.shape == (npix,) assert m.dtype.metadata == { @@ -207,6 +217,7 @@ def test_positions(mapper, catalog, vmap): "kernel": "healpix", "nside": mapper.nside, "lmax": mapper.lmax, + "deconv": mapper.deconvolve, "bias": pytest.approx(bias / nbar**2), } np.testing.assert_array_equal(m, 1.0) @@ -216,8 +227,8 @@ def test_positions(mapper, catalog, vmap): catalog.visibility = vmap nbar /= vmap.mean() - f = Positions("ra", "dec") - m = coroutines.run(f(catalog, mapper)) + f = Positions(mapper, "ra", "dec") + m = coroutines.run(f(catalog)) assert m.shape == (12 * mapper.nside**2,) assert m.dtype.metadata == { @@ -228,13 +239,14 @@ def test_positions(mapper, catalog, vmap): "kernel": "healpix", "nside": mapper.nside, "lmax": mapper.lmax, + "deconv": mapper.deconvolve, "bias": pytest.approx(bias / nbar**2), } # compute number count map with visibility map - f = Positions("ra", "dec", overdensity=False) - m = coroutines.run(f(catalog, mapper)) + f = Positions(mapper, "ra", "dec", overdensity=False) + m = coroutines.run(f(catalog)) assert m.shape == (12 * mapper.nside**2,) assert m.dtype.metadata == { @@ -245,14 +257,15 @@ def test_positions(mapper, catalog, vmap): "kernel": "healpix", "nside": mapper.nside, "lmax": mapper.lmax, + "deconv": mapper.deconvolve, "bias": pytest.approx(bias / nbar**2), } # compute overdensity maps with given (incorrect) nbar - f = Positions("ra", "dec", nbar=2 * nbar) + f = Positions(mapper, "ra", "dec", nbar=2 * nbar) with pytest.warns(UserWarning, match="mean density"): - m = coroutines.run(f(catalog, mapper)) + m = coroutines.run(f(catalog)) assert m.dtype.metadata["nbar"] == 2 * nbar assert m.dtype.metadata["bias"] == pytest.approx(bias / (2 * nbar) ** 2) @@ -263,8 +276,8 @@ def test_scalar_field(mapper, catalog): npix = 12 * mapper.nside**2 - f = ScalarField("ra", "dec", "g1", "w") - m = coroutines.run(f(catalog, mapper)) + f = ScalarField(mapper, "ra", "dec", "g1", "w") + m = coroutines.run(f(catalog)) w = next(iter(catalog))["w"] v = next(iter(catalog))["g1"] @@ -282,6 +295,7 @@ def test_scalar_field(mapper, catalog): "kernel": "healpix", "nside": mapper.nside, "lmax": mapper.lmax, + "deconv": mapper.deconvolve, "bias": pytest.approx(bias / wbar**2), } np.testing.assert_array_almost_equal(m, 0) @@ -292,8 +306,8 @@ def test_complex_field(mapper, catalog): npix = 12 * mapper.nside**2 - f = Spin2Field("ra", "dec", "g1", "g2", "w") - m = coroutines.run(f(catalog, mapper)) + f = Spin2Field(mapper, "ra", "dec", "g1", "g2", "w") + m = coroutines.run(f(catalog)) w = next(iter(catalog))["w"] re = next(iter(catalog))["g1"] @@ -312,6 +326,7 @@ def test_complex_field(mapper, catalog): "kernel": "healpix", "nside": mapper.nside, "lmax": mapper.lmax, + "deconv": mapper.deconvolve, "bias": pytest.approx(bias / wbar**2), } np.testing.assert_array_almost_equal(m, 0) @@ -322,8 +337,8 @@ def test_weights(mapper, catalog): npix = 12 * mapper.nside**2 - f = Weights("ra", "dec", "w") - m = coroutines.run(f(catalog, mapper)) + f = Weights(mapper, "ra", "dec", "w") + m = coroutines.run(f(catalog)) w = next(iter(catalog))["w"] v2 = (w**2).sum() @@ -340,6 +355,7 @@ def test_weights(mapper, catalog): "kernel": "healpix", "nside": mapper.nside, "lmax": mapper.lmax, + "deconv": mapper.deconvolve, "bias": pytest.approx(bias / wbar**2), } np.testing.assert_array_almost_equal(m, w / wbar) diff --git a/tests/test_maps.py b/tests/test_maps.py index 7171b61..f9a9269 100644 --- a/tests/test_maps.py +++ b/tests/test_maps.py @@ -158,22 +158,6 @@ def test_healpix_transform(mock_map2alm, rng): assert alms[1].dtype.metadata["nside"] == nside -class MockField: - def __init__(self): - self.args = [] - self.return_value = object() - - async def __call__(self, catalog, mapper, *, progress=None): - self.args.append((catalog, mapper)) - return self.return_value - - def assert_called_with(self, *args): - assert self.args[-1] == args - - def assert_any_call(self, *args): - assert args in self.args - - class MockCatalog: size = 10 page_size = 1 @@ -185,56 +169,55 @@ def __iter__(self): @pytest.mark.parametrize("parallel", [False, True]) def test_map_catalogs(parallel): - from heracles.maps import map_catalogs + from unittest.mock import AsyncMock - mapper = unittest.mock.Mock() + from heracles.maps import map_catalogs - fields = {"a": MockField(), "b": MockField(), "z": MockField()} + fields = {"a": AsyncMock(), "b": AsyncMock(), "z": AsyncMock()} catalogs = {"x": MockCatalog(), "y": MockCatalog()} - maps = map_catalogs(mapper, fields, catalogs, parallel=parallel) + maps = map_catalogs(fields, catalogs, parallel=parallel) for k in fields: for i in catalogs: - fields[k].assert_any_call(catalogs[i], mapper) + fields[k].assert_any_call(catalogs[i], progress=None) assert maps[k, i] is fields[k].return_value def test_map_catalogs_match(): + from unittest.mock import AsyncMock + from heracles.maps import map_catalogs - mapper = unittest.mock.Mock() - fields = {"a": MockField(), "b": MockField(), "c": MockField()} + fields = {"a": AsyncMock(), "b": AsyncMock(), "c": AsyncMock()} catalogs = {"x": MockCatalog(), "y": MockCatalog()} - maps = map_catalogs(mapper, fields, catalogs, include=[(..., "y")]) + maps = map_catalogs(fields, catalogs, include=[(..., "y")]) assert set(maps.keys()) == {("a", "y"), ("b", "y"), ("c", "y")} - maps = map_catalogs(mapper, fields, catalogs, exclude=[("a", ...)]) + maps = map_catalogs(fields, catalogs, exclude=[("a", ...)]) assert set(maps.keys()) == {("b", "x"), ("b", "y"), ("c", "x"), ("c", "y")} def test_transform_maps(rng): - from heracles.maps import transform_maps + from unittest.mock import Mock - alms_x = unittest.mock.Mock() - alms_ye = unittest.mock.Mock() - alms_yb = unittest.mock.Mock() + from heracles.maps import transform_maps - mapper = unittest.mock.Mock() - mapper.transform.side_effect = (alms_x, (alms_ye, alms_yb)) + x = Mock() + y = Mock() + x.mapper_or_error.transform.return_value = Mock() + y.mapper_or_error.transform.return_value = (Mock(), Mock()) - maps = { - ("X", 0): unittest.mock.Mock(), - ("Y", 1): unittest.mock.Mock(), - } + fields = {"X": x, "Y": y} + maps = {("X", 0): Mock(), ("Y", 1): Mock()} - alms = transform_maps(mapper, maps) + alms = transform_maps(fields, maps) assert len(alms) == 3 assert alms.keys() == {("X", 0), ("Y_E", 1), ("Y_B", 1)} - assert alms["X", 0] is alms_x - assert alms["Y_E", 1] is alms_ye - assert alms["Y_B", 1] is alms_yb + assert alms["X", 0] is x.mapper_or_error.transform.return_value + assert alms["Y_E", 1] is y.mapper_or_error.transform.return_value[0] + assert alms["Y_B", 1] is y.mapper_or_error.transform.return_value[1]