Skip to content

Commit

Permalink
Flag to include radec columns in index creation. (#378)
Browse files Browse the repository at this point in the history
* Flag to include radec columns in index creation.

* Preserve order of extra columns.

* Remove extra line.

* Address pylint warnings.

* Expand list comprehension to for-loop.
  • Loading branch information
delucchi-cmu authored Aug 16, 2024
1 parent 941b39a commit d87041a
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 0 deletions.
18 changes: 18 additions & 0 deletions src/hipscat_import/index/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ class IndexArguments(RuntimeArguments):
"""Include the hipscat spatial partition index."""
include_order_pixel: bool = True
"""Include partitioning columns, Norder, Dir, and Npix. You probably want to keep these!"""
include_radec: bool = False
"""Include the ra/dec coordinates of the row."""
drop_duplicates: bool = True
"""Should we check for duplicate rows (including new indexing column),
and remove duplicates before writing to new index catalog?
Expand Down Expand Up @@ -61,6 +63,21 @@ def _check_arguments(self):
self.input_catalog = Catalog.read_from_hipscat(
catalog_path=self.input_catalog_path, storage_options=self.input_storage_options
)
if self.include_radec:
catalog_info = self.input_catalog.catalog_info
self.extra_columns.extend([catalog_info.ra_column, catalog_info.dec_column])
if len(self.extra_columns) > 0:
# check that they're in the schema
schema = self.input_catalog.schema
missing_fields = [x for x in self.extra_columns if schema.get_field_index(x) == -1]
if len(missing_fields):
raise ValueError(f"Some requested columns not in input catalog ({','.join(missing_fields)})")
# Remove duplicates, preserving order
extra_columns = []
for x in self.extra_columns:
if x not in extra_columns:
extra_columns.append(x)
self.extra_columns = extra_columns

if self.compute_partition_size < 100_000:
raise ValueError("compute_partition_size must be at least 100_000")
Expand All @@ -84,4 +101,5 @@ def additional_runtime_provenance_info(self) -> dict:
"extra_columns": self.extra_columns,
"include_hipscat_index": self.include_hipscat_index,
"include_order_pixel": self.include_order_pixel,
"include_radec": self.include_radec,
}
31 changes: 31 additions & 0 deletions tests/hipscat_import/index/test_index_argument.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Tests of argument validation"""

import re

import pytest

from hipscat_import.index.arguments import IndexArguments
Expand Down Expand Up @@ -111,6 +113,35 @@ def test_column_inclusion_args(tmp_path, small_sky_object_catalog):
)


def test_extra_columns(tmp_path, small_sky_object_catalog):
    """Check how ``extra_columns`` is resolved during argument validation.

    Three cases are exercised against the small sky object catalog:

    * an explicitly requested column is kept as given,
    * ``include_radec=True`` with no other extras yields ``["ra", "dec"]``,
    * requesting a column that is absent from the input catalog raises
      ``ValueError`` naming the missing column.
    """
    # Keyword arguments common to every construction below.
    shared_kwargs = {
        "input_catalog_path": small_sky_object_catalog,
        "indexing_column": "id",
        "output_path": tmp_path,
        "output_artifact_name": "small_sky_object_index",
    }

    # Explicit extra column is passed through unchanged.
    explicit = IndexArguments(extra_columns=["_hipscat_index"], **shared_kwargs)
    assert explicit.extra_columns == ["_hipscat_index"]

    # The radec flag alone populates the extra columns.
    with_radec = IndexArguments(include_radec=True, **shared_kwargs)
    assert with_radec.extra_columns == ["ra", "dec"]

    # Unknown column: the message is escaped because it contains parentheses,
    # which would otherwise be treated as a regex group by ``match``.
    expected_message = re.escape("not in input catalog (mag_r)")
    with pytest.raises(ValueError, match=expected_message):
        IndexArguments(extra_columns=["mag_r"], **shared_kwargs)


def test_compute_partition_size(tmp_path, small_sky_object_catalog):
"""Test validation of compute_partition_size."""
with pytest.raises(ValueError, match="compute_partition_size"):
Expand Down

0 comments on commit d87041a

Please sign in to comment.