From d87041a03ccc9dec6378e0a36f1c99058cb0712a Mon Sep 17 00:00:00 2001 From: Melissa DeLucchi <113376043+delucchi-cmu@users.noreply.github.com> Date: Fri, 16 Aug 2024 13:59:35 -0400 Subject: [PATCH] Flag to include radec columns in index creation. (#378) * Flag to include radec columns in index creation. * Preserve order of extra columns. * Remove extra line. * pylint whackamole * Expand list comprehension to for-loop. --- src/hipscat_import/index/arguments.py | 18 +++++++++++ .../index/test_index_argument.py | 31 +++++++++++++++++++ 2 files changed, 49 insertions(+) diff --git a/src/hipscat_import/index/arguments.py b/src/hipscat_import/index/arguments.py index beacca19..13d7d25e 100644 --- a/src/hipscat_import/index/arguments.py +++ b/src/hipscat_import/index/arguments.py @@ -27,6 +27,8 @@ class IndexArguments(RuntimeArguments): """Include the hipscat spatial partition index.""" include_order_pixel: bool = True """Include partitioning columns, Norder, Dir, and Npix. You probably want to keep these!""" + include_radec: bool = False + """Include the ra/dec coordinates of the row.""" drop_duplicates: bool = True """Should we check for duplicate rows (including new indexing column), and remove duplicates before writing to new index catalog? @@ -61,6 +63,21 @@ def _check_arguments(self): self.input_catalog = Catalog.read_from_hipscat( catalog_path=self.input_catalog_path, storage_options=self.input_storage_options ) + if self.include_radec: + catalog_info = self.input_catalog.catalog_info + self.extra_columns.extend([catalog_info.ra_column, catalog_info.dec_column]) + if len(self.extra_columns) > 0: + # check that they're in the schema + schema = self.input_catalog.schema + missing_fields = [x for x in self.extra_columns if schema.get_field_index(x) == -1] + if len(missing_fields): + raise ValueError(f"Some requested columns not in input catalog ({','.join(missing_fields)})") + # Remove duplicates, preserving order + extra_columns = [] + for x in self.extra_columns: + if x not in extra_columns: + extra_columns.append(x) + self.extra_columns = extra_columns if self.compute_partition_size < 100_000: raise ValueError("compute_partition_size must be at least 100_000") @@ -84,4 +101,5 @@ def additional_runtime_provenance_info(self) -> dict: "extra_columns": self.extra_columns, "include_hipscat_index": self.include_hipscat_index, "include_order_pixel": self.include_order_pixel, + "include_radec": self.include_radec, } diff --git a/tests/hipscat_import/index/test_index_argument.py b/tests/hipscat_import/index/test_index_argument.py index bc56996c..2d97432e 100644 --- a/tests/hipscat_import/index/test_index_argument.py +++ b/tests/hipscat_import/index/test_index_argument.py @@ -1,5 +1,7 @@ """Tests of argument validation""" +import re + import pytest from hipscat_import.index.arguments import IndexArguments @@ -111,6 +113,35 @@ def test_column_inclusion_args(tmp_path, small_sky_object_catalog): ) +def test_extra_columns(tmp_path, small_sky_object_catalog): + args = IndexArguments( + input_catalog_path=small_sky_object_catalog, + indexing_column="id", + output_path=tmp_path, + output_artifact_name="small_sky_object_index", + extra_columns=["_hipscat_index"], + ) + assert args.extra_columns == ["_hipscat_index"] + + args = IndexArguments( + input_catalog_path=small_sky_object_catalog, + indexing_column="id", + output_path=tmp_path, + output_artifact_name="small_sky_object_index", + include_radec=True, + ) + assert args.extra_columns == ["ra", "dec"] + + with pytest.raises(ValueError, match=re.escape("not in input catalog (mag_r)")): + IndexArguments( + input_catalog_path=small_sky_object_catalog, + indexing_column="id", + output_path=tmp_path, + output_artifact_name="small_sky_object_index", + extra_columns=["mag_r"], + ) + + def test_compute_partition_size(tmp_path, small_sky_object_catalog): """Test validation of compute_partition_size.""" with pytest.raises(ValueError, match="compute_partition_size"):