diff --git a/src/hipscat_import/catalog/arguments.py b/src/hipscat_import/catalog/arguments.py
index 6dda8ddd..31651255 100644
--- a/src/hipscat_import/catalog/arguments.py
+++ b/src/hipscat_import/catalog/arguments.py
@@ -40,6 +40,8 @@ class ImportArguments(RuntimeArguments):
     """column for right ascension"""
     dec_column: str = "dec"
     """column for declination"""
+    use_hipscat_index: bool = False
+    """use an existing hipscat spatial index as the position, instead of ra/dec"""
     id_column: str = "id"
     """column for survey identifier, or other sortable column"""
     add_hipscat_index: bool = True
@@ -140,6 +142,7 @@ def additional_runtime_provenance_info(self) -> dict:
             "input_file_list": self.input_file_list,
             "ra_column": self.ra_column,
             "dec_column": self.dec_column,
+            "use_hipscat_index": self.use_hipscat_index,
             "id_column": self.id_column,
             "constant_healpix_order": self.constant_healpix_order,
             "highest_healpix_order": self.highest_healpix_order,
diff --git a/src/hipscat_import/catalog/map_reduce.py b/src/hipscat_import/catalog/map_reduce.py
index 8ba4d8a3..38d99b1b 100644
--- a/src/hipscat_import/catalog/map_reduce.py
+++ b/src/hipscat_import/catalog/map_reduce.py
@@ -6,6 +6,7 @@
 import pyarrow.parquet as pq
 from hipscat import pixel_math
 from hipscat.io import FilePointer, file_io, paths
+from hipscat.pixel_math.hipscat_id import HIPSCAT_ID_COLUMN, hipscat_id_to_healpix
 
 from hipscat_import.catalog.file_readers import InputReader
 from hipscat_import.catalog.resume_plan import ResumePlan
@@ -53,6 +54,7 @@ def _iterate_input_file(
     highest_order,
     ra_column,
     dec_column,
+    use_hipscat_index = False,
 ):
     """Helper function to handle input file reading and healpix pixel calculation"""
     if not file_reader:
@@ -61,18 +63,28 @@ def _iterate_input_file(
     required_columns = [ra_column, dec_column]
 
     for chunk_number, data in enumerate(file_reader.read(input_file)):
-        if not all(x in data.columns for x in required_columns):
-            raise ValueError(
-                f"Invalid column names in input file: {ra_column}, {dec_column} not in {input_file}"
+        if use_hipscat_index:
+            if data.index.name == HIPSCAT_ID_COLUMN:
+                mapped_pixels = hipscat_id_to_healpix(data.index, target_order=highest_order)
+            elif HIPSCAT_ID_COLUMN in data.columns:
+                mapped_pixels = hipscat_id_to_healpix(data[HIPSCAT_ID_COLUMN], target_order=highest_order)
+            else:
+                raise ValueError(
+                    f"Invalid column names in input file: {HIPSCAT_ID_COLUMN} not in {input_file}"
+                )
+        else:
+            if not all(x in data.columns for x in required_columns):
+                raise ValueError(
+                    f"Invalid column names in input file: {', '.join(required_columns)} not in {input_file}"
+                )
+            # Set up the pixel data
+            mapped_pixels = hp.ang2pix(
+                2**highest_order,
+                data[ra_column].values,
+                data[dec_column].values,
+                lonlat=True,
+                nest=True,
             )
-        # Set up the pixel data
-        mapped_pixels = hp.ang2pix(
-            2**highest_order,
-            data[ra_column].values,
-            data[dec_column].values,
-            lonlat=True,
-            nest=True,
-        )
 
         yield chunk_number, data, mapped_pixels
 
@@ -84,6 +96,7 @@ def map_to_pixels(
     highest_order,
     ra_column,
     dec_column,
+    use_hipscat_index = False
 ):
     """Map a file of input objects to their healpix pixels.
 
@@ -107,7 +120,7 @@ def map_to_pixels(
     """
     histo = pixel_math.empty_histogram(highest_order)
     for _, _, mapped_pixels in _iterate_input_file(
-        input_file, file_reader, highest_order, ra_column, dec_column
+        input_file, file_reader, highest_order, ra_column, dec_column, use_hipscat_index
     ):
         mapped_pixel, count_at_pixel = np.unique(mapped_pixels, return_counts=True)
         histo[mapped_pixel] += count_at_pixel.astype(np.int64)
@@ -124,6 +137,7 @@ def split_pixels(
     cache_shard_path: FilePointer,
     resume_path: FilePointer,
     alignment=None,
+    use_hipscat_index = False,
 ):
     """Map a file of input objects to their healpix pixels and split into shards.
 
@@ -144,7 +158,7 @@ def split_pixels(
         FileNotFoundError: if the file does not exist, or is a directory
     """
     for chunk_number, data, mapped_pixels in _iterate_input_file(
-        input_file, file_reader, highest_order, ra_column, dec_column
+        input_file, file_reader, highest_order, ra_column, dec_column, use_hipscat_index
    ):
         aligned_pixels = alignment[mapped_pixels]
         unique_pixels, unique_inverse = np.unique(aligned_pixels, return_inverse=True)
@@ -180,6 +194,7 @@ def reduce_pixel_shards(
     ra_column,
     dec_column,
     id_column,
+    use_hipscat_index = False,
     add_hipscat_index=True,
     delete_input_files=True,
     use_schema_file="",
@@ -259,8 +274,8 @@ def reduce_pixel_shards(
     dataframe = merged_table.to_pandas()
     if id_column:
         dataframe = dataframe.sort_values(id_column)
-    if add_hipscat_index:
-        dataframe["_hipscat_index"] = pixel_math.compute_hipscat_id(
+    if add_hipscat_index and not use_hipscat_index:
+        dataframe[HIPSCAT_ID_COLUMN] = pixel_math.compute_hipscat_id(
             dataframe[ra_column].values,
             dataframe[dec_column].values,
         )
@@ -277,7 +292,7 @@ def reduce_pixel_shards(
     ## If we had a meaningful index before, preserve it as a column.
     if _has_named_index(dataframe):
         dataframe = dataframe.reset_index()
-    dataframe = dataframe.set_index("_hipscat_index").sort_index()
+    dataframe = dataframe.set_index(HIPSCAT_ID_COLUMN).sort_index()
     dataframe.to_parquet(destination_file)
 
     del dataframe, merged_table, tables
diff --git a/src/hipscat_import/catalog/run_import.py b/src/hipscat_import/catalog/run_import.py
index 218bd539..36a7476c 100644
--- a/src/hipscat_import/catalog/run_import.py
+++ b/src/hipscat_import/catalog/run_import.py
@@ -35,6 +35,7 @@ def _map_pixels(args, client):
                 highest_order=args.mapping_healpix_order,
                 ra_column=args.ra_column,
                 dec_column=args.dec_column,
+                use_hipscat_index=args.use_hipscat_index,
             )
         )
     args.resume_plan.wait_for_mapping(futures)
@@ -62,6 +63,7 @@ def _split_pixels(args, alignment_future, client):
                 cache_shard_path=args.tmp_path,
                 resume_path=args.resume_plan.tmp_path,
                 alignment=alignment_future,
+                use_hipscat_index=args.use_hipscat_index,
             )
         )
 
@@ -96,6 +98,7 @@ def _reduce_pixels(args, destination_pixel_map, client):
                 id_column=args.id_column,
                 add_hipscat_index=args.add_hipscat_index,
                 use_schema_file=args.use_schema_file,
+                use_hipscat_index=args.use_hipscat_index,
             )
         )
 
diff --git a/src/hipscat_import/index/map_reduce.py b/src/hipscat_import/index/map_reduce.py
index a7cb5399..e446a2a9 100644
--- a/src/hipscat_import/index/map_reduce.py
+++ b/src/hipscat_import/index/map_reduce.py
@@ -4,6 +4,7 @@
 import numpy as np
 from dask.distributed import progress, wait
 from hipscat.io import file_io
+from hipscat.pixel_math.hipscat_id import HIPSCAT_ID_COLUMN
 
 
 def create_index(args):
@@ -31,7 +32,7 @@ def create_index(args):
         data["Npix"] = data["Npix"].astype(np.int32)
         data = data.reset_index()
     if not args.include_hipscat_index:
-        data = data.drop(columns=["_hipscat_index"])
+        data = data.drop(columns=[HIPSCAT_ID_COLUMN])
     data = data.repartition(partition_size=args.compute_partition_size)
     data = data.set_index(args.indexing_column)
     result = data.to_parquet(
diff --git a/tests/hipscat_import/catalog/test_run_round_trip.py b/tests/hipscat_import/catalog/test_run_round_trip.py
index 28ac954b..479b8bb4 100644
--- a/tests/hipscat_import/catalog/test_run_round_trip.py
+++ b/tests/hipscat_import/catalog/test_run_round_trip.py
@@ -353,3 +353,59 @@ def read(self, input_file):
 
     expected_ids = [*range(700, 831)]
     assert_parquet_file_ids(output_file, "id", expected_ids)
+
+
+@pytest.mark.dask
+def test_import_hipscat_index(
+    dask_client,
+    formats_dir,
+    assert_parquet_file_ids,
+    tmp_path,
+):
+    """Test basic execution, using a previously-computed _hipscat_index column for spatial partitioning."""
+    ## First, let's just check the assumptions we have about our input file:
+    ## - should have _hipscat_index as the indexed column
+    ## - should NOT have any columns like "ra" or "dec"
+    input_file = os.path.join(formats_dir, "hipscat_index.parquet")
+
+    expected_ids = [*range(700, 831)]
+    assert_parquet_file_ids(input_file, "id", expected_ids)
+
+    data_frame = pd.read_parquet(input_file, engine="pyarrow")
+    assert data_frame.index.name == "_hipscat_index"
+    npt.assert_array_equal(data_frame.columns, ["id"])
+
+    args = ImportArguments(
+        output_catalog_name="using_hipscat_index",
+        input_file_list=[input_file],
+        input_format="parquet",
+        output_path=tmp_path,
+        dask_tmp=tmp_path,
+        use_hipscat_index=True,
+        add_hipscat_index=False,
+        highest_healpix_order=2,
+        pixel_threshold=3_000,
+        progress_bar=False,
+        id_column="_hipscat_index",
+    )
+
+    runner.run(args, dask_client)
+
+    # Check that the catalog metadata file exists
+    catalog = Catalog.read_from_hipscat(args.catalog_path)
+    assert catalog.on_disk
+    assert catalog.catalog_path == args.catalog_path
+    assert catalog.catalog_info.total_rows == 131
+    assert len(catalog.get_healpix_pixels()) == 1
+
+    # Check that the catalog parquet file exists and contains correct object IDs
+    output_file = os.path.join(args.catalog_path, "Norder=0", "Dir=0", "Npix=11.parquet")
+
+    expected_ids = [*range(700, 831)]
+    assert_parquet_file_ids(output_file, "id", expected_ids)
+    data_frame = pd.read_parquet(output_file, engine="pyarrow")
+    assert data_frame.index.name == "_hipscat_index"
+    npt.assert_array_equal(
+        data_frame.columns,
+        ["id", "Norder", "Dir", "Npix"],
+    )
diff --git a/tests/hipscat_import/data/test_formats/hipscat_index.parquet b/tests/hipscat_import/data/test_formats/hipscat_index.parquet
new file mode 100644
index 00000000..44bdf663
Binary files /dev/null and b/tests/hipscat_import/data/test_formats/hipscat_index.parquet differ