diff --git a/pyproject.toml b/pyproject.toml index 2257d37..931aa25 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,7 +18,7 @@ dependencies = [ "dask[complete]>=2024.3.0", # Includes dask expressions. "deprecated", "healpy", - "hipscat >=0.3.0", + "hipscat >=0.3.4", "ipykernel", # Support for Jupyter notebooks "numpy", "pandas", diff --git a/src/hipscat_import/margin_cache/margin_cache_arguments.py b/src/hipscat_import/margin_cache/margin_cache_arguments.py index 4f30ae6..8c22d4a 100644 --- a/src/hipscat_import/margin_cache/margin_cache_arguments.py +++ b/src/hipscat_import/margin_cache/margin_cache_arguments.py @@ -1,11 +1,12 @@ import warnings -from dataclasses import dataclass -from typing import Any, Dict, Union +from dataclasses import dataclass, field +from typing import Any, Dict, List, Union import healpy as hp from hipscat.catalog import Catalog from hipscat.catalog.margin_cache.margin_cache_catalog_info import MarginCacheCatalogInfo from hipscat.io.validation import is_valid_catalog +from hipscat.pixel_math.healpix_pixel import HealpixPixel from hipscat_import.runtime_arguments import RuntimeArguments @@ -35,6 +36,9 @@ class MarginCacheArguments(RuntimeArguments): """the path to the hipscat-formatted input catalog.""" input_storage_options: Union[Dict[Any, Any], None] = None """optional dictionary of abstract filesystem credentials for the INPUT.""" + debug_filter_pixel_list: List[HealpixPixel] = field(default_factory=list) + """debug setting. if provided, we will first filter the catalog to the pixels + provided. this can be useful for creating a margin over a subset of a catalog.""" def __post_init__(self): self._check_arguments() @@ -49,8 +53,12 @@ def _check_arguments(self): self.catalog = Catalog.read_from_hipscat( self.input_catalog_path, storage_options=self.input_storage_options ) + if len(self.debug_filter_pixel_list) > 0: + self.catalog = self.catalog.filter_from_pixel_list(self.debug_filter_pixel_list) + if len(self.catalog.get_healpix_pixels()) == 0: + raise ValueError("debug_filter_pixel_list has created empty catalog") - highest_order = self.catalog.partition_info.get_highest_order() + highest_order = int(self.catalog.partition_info.get_highest_order()) margin_pixel_k = highest_order + 1 if self.margin_order > -1: if self.margin_order < margin_pixel_k: @@ -85,4 +93,5 @@ def additional_runtime_provenance_info(self) -> dict: "input_catalog_path": self.input_catalog_path, "margin_threshold": self.margin_threshold, "margin_order": self.margin_order, + "debug_filter_pixel_list": self.debug_filter_pixel_list, } diff --git a/tests/hipscat_import/margin_cache/test_arguments_margin_cache.py b/tests/hipscat_import/margin_cache/test_arguments_margin_cache.py index 57564ce..546fad0 100644 --- a/tests/hipscat_import/margin_cache/test_arguments_margin_cache.py +++ b/tests/hipscat_import/margin_cache/test_arguments_margin_cache.py @@ -1,11 +1,11 @@ """Tests of margin cache generation arguments""" import pytest +from hipscat.io import write_metadata +from hipscat.pixel_math.healpix_pixel import HealpixPixel from hipscat_import.margin_cache.margin_cache_arguments import MarginCacheArguments -# pylint: disable=protected-access - def test_empty_required(tmp_path): """*Most* required arguments are provided.""" @@ -64,6 +64,42 @@ def test_margin_order_invalid(small_sky_source_catalog, tmp_path): ) +def test_debug_filter_pixel_list(small_sky_source_catalog, tmp_path): + """Ensure we can generate catalog with a filtereed list of pixels, and + that we raise an exception when the filter results in an empty catalog.""" + args = MarginCacheArguments( + margin_threshold=5.0, + input_catalog_path=small_sky_source_catalog, + output_path=tmp_path, + output_artifact_name="catalog_cache", + margin_order=4, + debug_filter_pixel_list=[HealpixPixel(0, 11)], + ) + + assert len(args.catalog.get_healpix_pixels()) == 13 + + args = MarginCacheArguments( + margin_threshold=5.0, + input_catalog_path=small_sky_source_catalog, + output_path=tmp_path, + output_artifact_name="catalog_cache", + margin_order=4, + debug_filter_pixel_list=[HealpixPixel(1, 44)], + ) + + assert len(args.catalog.get_healpix_pixels()) == 4 + + with pytest.raises(ValueError, match="debug_filter_pixel_list"): + MarginCacheArguments( + margin_threshold=5.0, + input_catalog_path=small_sky_source_catalog, + output_path=tmp_path, + output_artifact_name="catalog_cache", + margin_order=4, + debug_filter_pixel_list=[HealpixPixel(0, 5)], + ) + + def test_margin_threshold_warns(small_sky_source_catalog, tmp_path): """Ensure we give a warning when margin_threshold is greater than margin_order resolution""" @@ -99,7 +135,12 @@ def test_provenance_info(small_sky_source_catalog, tmp_path): output_path=tmp_path, output_artifact_name="catalog_cache", margin_order=4, + debug_filter_pixel_list=[HealpixPixel(1, 44)], ) runtime_args = args.provenance_info()["runtime_args"] assert "margin_threshold" in runtime_args + + write_metadata.write_provenance_info( + catalog_base_dir=args.catalog_path, dataset_info=args.to_catalog_info(20_000), tool_args=runtime_args + )