diff --git a/src/hipscat_import/association/arguments.py b/src/hipscat_import/association/arguments.py index 28b8f852..e1b717e2 100644 --- a/src/hipscat_import/association/arguments.py +++ b/src/hipscat_import/association/arguments.py @@ -2,7 +2,9 @@ from dataclasses import dataclass -from hipscat.catalog import CatalogParameters +from hipscat.catalog.association_catalog.association_catalog import ( + AssociationCatalogInfo, +) from hipscat_import.runtime_arguments import RuntimeArguments @@ -51,19 +53,20 @@ def _check_arguments(self): if self.compute_partition_size < 100_000: raise ValueError("compute_partition_size must be at least 100_000") - def to_catalog_parameters(self) -> CatalogParameters: - """Convert importing arguments into hipscat catalog parameters. - - Returns: - CatalogParameters for catalog being created. - """ - return CatalogParameters( - catalog_name=self.output_catalog_name, - catalog_type="association", - output_path=self.output_path, - ) + def to_catalog_info(self, total_rows) -> AssociationCatalogInfo: + """Catalog-type-specific dataset info.""" + info = { + "catalog_name": self.output_catalog_name, + "catalog_type": "association", + "total_rows": total_rows, + "primary_column": self.primary_id_column, + "primary_catalog": str(self.primary_input_catalog_path), + "join_column": self.join_id_column, + "join_catalog": str(self.join_input_catalog_path), + } + return AssociationCatalogInfo(**info) - def additional_runtime_provenance_info(self): + def additional_runtime_provenance_info(self) -> dict: return { "primary_input_catalog_path": str(self.primary_input_catalog_path), "primary_id_column": self.primary_id_column, diff --git a/src/hipscat_import/association/run_association.py b/src/hipscat_import/association/run_association.py index ba505967..461a36f1 100644 --- a/src/hipscat_import/association/run_association.py +++ b/src/hipscat_import/association/run_association.py @@ -9,7 +9,8 @@ from tqdm import tqdm from hipscat_import.association.arguments import AssociationArguments -from hipscat_import.association.map_reduce import map_association, reduce_association +from hipscat_import.association.map_reduce import (map_association, + reduce_association) def _validate_args(args): @@ -40,11 +41,17 @@ def run(args): ) as step_progress: # pylint: disable=duplicate-code # Very similar to /index/run_index.py - catalog_params = args.to_catalog_parameters() - catalog_params.total_rows = int(rows_written) - write_metadata.write_provenance_info(catalog_params, args.provenance_info()) + catalog_info = args.to_catalog_info(int(rows_written)) + write_metadata.write_provenance_info( + catalog_base_dir=args.catalog_path, + dataset_info=catalog_info, + tool_args=args.provenance_info(), + ) step_progress.update(1) - write_metadata.write_catalog_info(catalog_params) + catalog_info = args.to_catalog_info(total_rows=int(rows_written)) + write_metadata.write_catalog_info( + dataset_info=catalog_info, catalog_base_dir=args.catalog_path + ) step_progress.update(1) write_metadata.write_parquet_metadata(args.catalog_path) step_progress.update(1) diff --git a/src/hipscat_import/catalog/arguments.py b/src/hipscat_import/catalog/arguments.py index 7fa98135..38b59c11 100644 --- a/src/hipscat_import/catalog/arguments.py +++ b/src/hipscat_import/catalog/arguments.py @@ -6,7 +6,7 @@ from typing import Callable, List import pandas as pd -from hipscat.catalog import CatalogParameters +from hipscat.catalog.catalog import CatalogInfo from hipscat.io import FilePointer, file_io from hipscat.pixel_math 
import hipscat_id @@ -140,21 +140,19 @@ def _check_arguments(self): if not self.filter_function: self.filter_function = passthrough_filter_function - def to_catalog_parameters(self) -> CatalogParameters: - """Convert importing arguments into hipscat catalog parameters. - Returns: - CatalogParameters for catalog being created. - """ - return CatalogParameters( - catalog_name=self.output_catalog_name, - catalog_type=self.catalog_type, - output_path=self.output_path, - epoch=self.epoch, - ra_column=self.ra_column, - dec_column=self.dec_column, - ) + def to_catalog_info(self, total_rows) -> CatalogInfo: + """Catalog-type-specific dataset info.""" + info = { + "catalog_name": self.output_catalog_name, + "catalog_type": self.catalog_type, + "total_rows": total_rows, + "epoch": self.epoch, + "ra_column": self.ra_column, + "dec_column": self.dec_column, + } + return CatalogInfo(**info) - def additional_runtime_provenance_info(self): + def additional_runtime_provenance_info(self) -> dict: return { "catalog_name": self.output_catalog_name, "epoch": self.epoch, @@ -171,7 +169,9 @@ def additional_runtime_provenance_info(self): "pixel_threshold": self.pixel_threshold, "mapping_healpix_order": self.mapping_healpix_order, "debug_stats_only": self.debug_stats_only, - "file_reader_info": self.file_reader.provenance_info(), + "file_reader_info": self.file_reader.provenance_info() + if self.file_reader is not None + else {}, } @@ -180,6 +180,7 @@ def check_healpix_order_range( ): """Helper method to check if the `order` is within the range determined by the `lower_bound` and `upper_bound`, inclusive. + Args: order (int): healpix order to check field_name (str): field name to use in the error message diff --git a/src/hipscat_import/catalog/run_import.py b/src/hipscat_import/catalog/run_import.py index 9b32416c..273e1da7 100644 --- a/src/hipscat_import/catalog/run_import.py +++ b/src/hipscat_import/catalog/run_import.py @@ -50,13 +50,13 @@ def _map_pixels(args, client): total=len(futures), disable=(not args.progress_bar), ): - if future.status == "error": # pragma: no cover + if future.status == "error": # pragma: no cover some_error = True raw_histogram = np.add(raw_histogram, result) resume.write_mapping_start_key(args.tmp_path, future.key) resume.write_histogram(args.tmp_path, raw_histogram) resume.write_mapping_done_key(args.tmp_path, future.key) - if some_error: # pragma: no cover + if some_error: # pragma: no cover raise RuntimeError("Some mapping stages failed. See logs for details.") resume.set_mapping_done(args.tmp_path) return raw_histogram @@ -98,10 +98,10 @@ def _split_pixels(args, alignment_future, client): total=len(futures), disable=(not args.progress_bar), ): - if future.status == "error": # pragma: no cover + if future.status == "error": # pragma: no cover some_error = True resume.write_splitting_done_key(args.tmp_path, future.key) - if some_error: # pragma: no cover + if some_error: # pragma: no cover raise RuntimeError("Some splitting stages failed. See logs for details.") resume.set_splitting_done(args.tmp_path) @@ -143,10 +143,10 @@ def _reduce_pixels(args, destination_pixel_map, client): total=len(futures), disable=(not args.progress_bar), ): - if future.status == "error": # pragma: no cover + if future.status == "error": # pragma: no cover some_error = True resume.write_reducing_key(args.tmp_path, future.key) - if some_error: # pragma: no cover + if some_error: # pragma: no cover raise RuntimeError("Some reducing stages failed. 
See logs for details.") resume.set_reducing_done(args.tmp_path) @@ -215,12 +215,17 @@ def run_with_client(args, client): with tqdm( total=6, desc="Finishing", disable=not args.progress_bar ) as step_progress: - catalog_parameters = args.to_catalog_parameters() - catalog_parameters.total_rows = int(raw_histogram.sum()) - io.write_provenance_info(catalog_parameters, args.provenance_info()) + catalog_info = args.to_catalog_info(int(raw_histogram.sum())) + io.write_provenance_info( + catalog_base_dir=args.catalog_path, + dataset_info=catalog_info, + tool_args=args.provenance_info(), + ) step_progress.update(1) - io.write_catalog_info(catalog_parameters) + io.write_catalog_info( + catalog_base_dir=args.catalog_path, dataset_info=catalog_info + ) step_progress.update(1) if not args.debug_stats_only: io.write_parquet_metadata(args.catalog_path) @@ -228,7 +233,8 @@ def run_with_client(args, client): io.write_fits_map(args.catalog_path, raw_histogram) step_progress.update(1) io.write_partition_info( - catalog_parameters, destination_healpix_pixel_map=destination_pixel_map + catalog_base_dir=args.catalog_path, + destination_healpix_pixel_map=destination_pixel_map, ) step_progress.update(1) resume.clean_resume_files(args.tmp_path) diff --git a/src/hipscat_import/index/arguments.py b/src/hipscat_import/index/arguments.py index d7ba2856..88fd957c 100644 --- a/src/hipscat_import/index/arguments.py +++ b/src/hipscat_import/index/arguments.py @@ -3,7 +3,8 @@ from dataclasses import dataclass, field from typing import List, Optional -from hipscat.catalog import Catalog, CatalogParameters +from hipscat.catalog import Catalog +from hipscat.catalog.index.index_catalog_info import IndexCatalogInfo from hipscat_import.runtime_arguments import RuntimeArguments @@ -46,19 +47,19 @@ def _check_arguments(self): if self.compute_partition_size < 100_000: raise ValueError("compute_partition_size must be at least 100_000") - def to_catalog_parameters(self) -> CatalogParameters: - """Convert importing arguments into hipscat catalog parameters. - - Returns: - CatalogParameters for catalog being created. 
- """ - return CatalogParameters( - catalog_name=self.output_catalog_name, - catalog_type="index", - output_path=self.output_path, - ) + def to_catalog_info(self, total_rows) -> IndexCatalogInfo: + """Catalog-type-specific dataset info.""" + info = { + "catalog_name": self.output_catalog_name, + "total_rows": total_rows, + "catalog_type": "index", + "primary_catalog": str(self.input_catalog_path), + "indexing_column": self.indexing_column, + "extra_columns": self.extra_columns, + } + return IndexCatalogInfo(**info) - def additional_runtime_provenance_info(self): + def additional_runtime_provenance_info(self) -> dict: return { "input_catalog_path": str(self.input_catalog_path), "indexing_column": self.indexing_column, diff --git a/src/hipscat_import/index/run_index.py b/src/hipscat_import/index/run_index.py index d30b6f47..46a7bd3e 100644 --- a/src/hipscat_import/index/run_index.py +++ b/src/hipscat_import/index/run_index.py @@ -25,11 +25,16 @@ def run(args): ) as step_progress: # pylint: disable=duplicate-code # Very similar to /association/run_association.py - catalog_params = args.to_catalog_parameters() - catalog_params.total_rows = int(rows_written) - write_metadata.write_provenance_info(catalog_params, args.provenance_info()) + catalog_info = args.to_catalog_info(int(rows_written)) + write_metadata.write_provenance_info( + catalog_base_dir=args.catalog_path, + dataset_info=catalog_info, + tool_args=args.provenance_info(), + ) step_progress.update(1) - write_metadata.write_catalog_info(catalog_params) + write_metadata.write_catalog_info( + catalog_base_dir=args.catalog_path, dataset_info=catalog_info + ) step_progress.update(1) file_io.remove_directory(args.tmp_path, ignore_errors=True) step_progress.update(1) diff --git a/src/hipscat_import/margin_cache/margin_cache_arguments.py b/src/hipscat_import/margin_cache/margin_cache_arguments.py index 31174d20..e6412390 100644 --- a/src/hipscat_import/margin_cache/margin_cache_arguments.py +++ b/src/hipscat_import/margin_cache/margin_cache_arguments.py @@ -4,6 +4,9 @@ import healpy as hp import numpy as np from hipscat.catalog import Catalog +from hipscat.catalog.margin_cache.margin_cache_catalog_info import ( + MarginCacheCatalogInfo, +) from hipscat.io import file_io from hipscat_import.runtime_arguments import RuntimeArguments @@ -53,9 +56,30 @@ def _check_arguments(self): margin_pixel_nside = hp.order2nside(self.margin_order) - if hp.nside2resol(margin_pixel_nside, arcmin=True) * 60. < self.margin_threshold: + if ( + hp.nside2resol(margin_pixel_nside, arcmin=True) * 60.0 + < self.margin_threshold + ): # pylint: disable=line-too-long warnings.warn( "Warning: margin pixels have a smaller resolution than margin_threshold; this may lead to data loss in the margin cache." 
) # pylint: enable=line-too-long + + def to_catalog_info(self, total_rows) -> MarginCacheCatalogInfo: + """Catalog-type-specific dataset info.""" + info = { + "catalog_name": self.output_catalog_name, + "total_rows": total_rows, + "catalog_type": "margin", + "primary_catalog": self.input_catalog_path, + "margin_threshold": self.margin_threshold, + } + return MarginCacheCatalogInfo(**info) + + def additional_runtime_provenance_info(self) -> dict: + return { + "input_catalog_path": str(self.input_catalog_path), + "margin_threshold": self.margin_threshold, + "margin_order": self.margin_order, + } diff --git a/tests/hipscat_import/association/test_association_argument.py b/tests/hipscat_import/association/test_association_argument.py index b14c785a..bf82055b 100644 --- a/tests/hipscat_import/association/test_association_argument.py +++ b/tests/hipscat_import/association/test_association_argument.py @@ -181,8 +181,8 @@ def test_all_required_args(tmp_path, small_sky_object_catalog): ) -def test_to_catalog_parameters(small_sky_object_catalog, tmp_path): - """Verify creation of catalog parameters for index to be created.""" +def test_to_catalog_info(small_sky_object_catalog, tmp_path): + """Verify creation of catalog parameters for association table to be created.""" args = AssociationArguments( primary_input_catalog_path=small_sky_object_catalog, primary_id_column="id", @@ -193,8 +193,9 @@ def test_to_catalog_parameters(small_sky_object_catalog, tmp_path): output_path=tmp_path, output_catalog_name="small_sky_self_join", ) - catalog_parameters = args.to_catalog_parameters() - assert catalog_parameters.catalog_name == args.output_catalog_name + catalog_info = args.to_catalog_info(total_rows=10) + assert catalog_info.catalog_name == args.output_catalog_name + assert catalog_info.total_rows == 10 def test_provenance_info(small_sky_object_catalog, tmp_path): diff --git a/tests/hipscat_import/association/test_run_association.py b/tests/hipscat_import/association/test_run_association.py index c1aa4b73..c09a31f7 100644 --- a/tests/hipscat_import/association/test_run_association.py +++ b/tests/hipscat_import/association/test_run_association.py @@ -6,6 +6,8 @@ import numpy.testing as npt import pandas as pd import pytest +from hipscat.catalog.association_catalog.association_catalog import \ + AssociationCatalog import hipscat_import.association.run_association as runner from hipscat_import.association.arguments import AssociationArguments @@ -29,9 +31,8 @@ def test_object_to_source( small_sky_object_catalog, small_sky_source_catalog, tmp_path, - assert_text_file_matches, ): - """test stuff""" + """Test creating association between object and source catalogs.""" args = AssociationArguments( primary_input_catalog_path=small_sky_object_catalog, @@ -46,40 +47,11 @@ def test_object_to_source( ) runner.run(args) - # Check that the catalog metadata file exists - expected_metadata_lines = [ - "{", - ' "catalog_name": "small_sky_association",', - ' "catalog_type": "association",', - ' "epoch": "J2000",', - ' "ra_kw": "ra",', - ' "dec_kw": "dec",', - ' "total_rows": 17161', - "}", - ] - metadata_filename = os.path.join(args.catalog_path, "catalog_info.json") - assert_text_file_matches(expected_metadata_lines, metadata_filename) - - # Check that the partition *join* info file exists - expected_lines = [ - "Norder,Dir,Npix,join_Norder,join_Dir,join_Npix,num_rows", - "0,0,11,0,0,4,50", - "0,0,11,1,0,47,2395", - "0,0,11,2,0,176,385", - "0,0,11,2,0,177,1510", - "0,0,11,2,0,178,1634", - "0,0,11,2,0,179,1773", - 
"0,0,11,2,0,180,655", - "0,0,11,2,0,181,903", - "0,0,11,2,0,182,1246", - "0,0,11,2,0,183,1143", - "0,0,11,2,0,184,1390", - "0,0,11,2,0,185,2942", - "0,0,11,2,0,186,452", - "0,0,11,2,0,187,683", - ] - metadata_filename = os.path.join(args.catalog_path, "partition_join_info.csv") - assert_text_file_matches(expected_lines, metadata_filename) + ## Check that the association data can be parsed as a valid association catalog. + catalog = AssociationCatalog.read_from_hipscat(args.catalog_path) + assert catalog.on_disk + assert catalog.catalog_path == args.catalog_path + assert len(catalog.get_join_pixels()) == 14 ## Test one pixel that will have 50 rows in it. output_file = os.path.join( @@ -105,15 +77,19 @@ def test_object_to_source( ids = data_frame["join_id"] assert np.logical_and(ids >= 70_000, ids < 87161).all() + catalog = AssociationCatalog.read_from_hipscat(args.catalog_path) + assert catalog.on_disk + assert catalog.catalog_path == args.catalog_path + assert len(catalog.get_join_pixels()) == 14 + @pytest.mark.dask def test_source_to_object( small_sky_object_catalog, small_sky_source_catalog, tmp_path, - assert_text_file_matches, ): - """test stuff""" + """Test creating (weirder) association between source and object catalogs.""" args = AssociationArguments( primary_input_catalog_path=small_sky_source_catalog, @@ -128,40 +104,11 @@ def test_source_to_object( ) runner.run(args) - # Check that the catalog metadata file exists - expected_metadata_lines = [ - "{", - ' "catalog_name": "small_sky_association",', - ' "catalog_type": "association",', - ' "epoch": "J2000",', - ' "ra_kw": "ra",', - ' "dec_kw": "dec",', - ' "total_rows": 17161', - "}", - ] - metadata_filename = os.path.join(args.catalog_path, "catalog_info.json") - assert_text_file_matches(expected_metadata_lines, metadata_filename) - - # Check that the partition *join* info file exists - expected_lines = [ - "Norder,Dir,Npix,join_Norder,join_Dir,join_Npix,num_rows", - "0,0,4,0,0,11,50", - "1,0,47,0,0,11,2395", - "2,0,176,0,0,11,385", - "2,0,177,0,0,11,1510", - "2,0,178,0,0,11,1634", - "2,0,179,0,0,11,1773", - "2,0,180,0,0,11,655", - "2,0,181,0,0,11,903", - "2,0,182,0,0,11,1246", - "2,0,183,0,0,11,1143", - "2,0,184,0,0,11,1390", - "2,0,185,0,0,11,2942", - "2,0,186,0,0,11,452", - "2,0,187,0,0,11,683", - ] - metadata_filename = os.path.join(args.catalog_path, "partition_join_info.csv") - assert_text_file_matches(expected_lines, metadata_filename) + ## Check that the association data can be parsed as a valid association catalog. + catalog = AssociationCatalog.read_from_hipscat(args.catalog_path) + assert catalog.on_disk + assert catalog.catalog_path == args.catalog_path + assert len(catalog.get_join_pixels()) == 14 ## Test one pixel that will have 50 rows in it. 
output_file = os.path.join( @@ -192,9 +139,8 @@ def test_source_to_object( def test_self_join( small_sky_object_catalog, tmp_path, - assert_text_file_matches, ): - """test stuff""" + """Test creating association between object catalog and itself.""" args = AssociationArguments( primary_input_catalog_path=small_sky_object_catalog, @@ -209,27 +155,11 @@ def test_self_join( ) runner.run(args) - # Check that the catalog metadata file exists - expected_metadata_lines = [ - "{", - ' "catalog_name": "small_sky_self_association",', - ' "catalog_type": "association",', - ' "epoch": "J2000",', - ' "ra_kw": "ra",', - ' "dec_kw": "dec",', - ' "total_rows": 131', - "}", - ] - metadata_filename = os.path.join(args.catalog_path, "catalog_info.json") - assert_text_file_matches(expected_metadata_lines, metadata_filename) - - # Check that the partition *join* info file exists - expected_lines = [ - "Norder,Dir,Npix,join_Norder,join_Dir,join_Npix,num_rows", - "0,0,11,0,0,11,131", - ] - metadata_filename = os.path.join(args.catalog_path, "partition_join_info.csv") - assert_text_file_matches(expected_lines, metadata_filename) + ## Check that the association data can be parsed as a valid association catalog. + catalog = AssociationCatalog.read_from_hipscat(args.catalog_path) + assert catalog.on_disk + assert catalog.catalog_path == args.catalog_path + assert len(catalog.get_join_pixels()) == 1 ## Test one pixel that will have 50 rows in it. output_file = os.path.join( diff --git a/tests/hipscat_import/catalog/test_argument_validation.py b/tests/hipscat_import/catalog/test_argument_validation.py index 22a77bb1..9b9694df 100644 --- a/tests/hipscat_import/catalog/test_argument_validation.py +++ b/tests/hipscat_import/catalog/test_argument_validation.py @@ -168,7 +168,7 @@ def test_catalog_type(blank_data_dir, tmp_path): ) -def test_to_catalog_parameters(blank_data_dir, tmp_path): +def test_to_catalog_info(blank_data_dir, tmp_path): """Verify creation of catalog parameters for catalog to be created.""" args = ImportArguments( output_catalog_name="catalog", @@ -177,8 +177,9 @@ def test_to_catalog_parameters(blank_data_dir, tmp_path): output_path=tmp_path, tmp_dir=tmp_path, ) - catalog_parameters = args.to_catalog_parameters() - assert catalog_parameters.catalog_name == "catalog" + catalog_info = args.to_catalog_info(total_rows=10) + assert catalog_info.catalog_name == "catalog" + assert catalog_info.total_rows == 10 def test_provenance_info(blank_data_dir, tmp_path): diff --git a/tests/hipscat_import/catalog/test_file_readers.py b/tests/hipscat_import/catalog/test_file_readers.py index 87ef60b6..ff56229d 100644 --- a/tests/hipscat_import/catalog/test_file_readers.py +++ b/tests/hipscat_import/catalog/test_file_readers.py @@ -7,7 +7,7 @@ import pyarrow as pa import pyarrow.parquet as pq import pytest -from hipscat.catalog import CatalogParameters +from hipscat.catalog.catalog import CatalogInfo from hipscat_import.catalog.file_readers import ( CsvReader, @@ -17,6 +17,19 @@ ) +# pylint: disable=redefined-outer-name +@pytest.fixture +def basic_catalog_info(): + info = { + "catalog_name": "test_catalog", + "catalog_type": "object", + "total_rows": 100, + "ra_column": "ra", + "dec_column": "dec", + } + return CatalogInfo(**info) + + def test_unknown_file_type(): """File reader factory method should fail for unknown file types""" with pytest.raises(NotImplementedError): @@ -187,7 +200,7 @@ def test_csv_reader_pipe_delimited(formats_pipe_csv, tmp_path): assert np.all(column_types == expected_column_types) -def 
test_csv_reader_provenance_info(tmp_path): +def test_csv_reader_provenance_info(tmp_path, basic_catalog_info): """Test that we get some provenance info and it is parseable into JSON.""" reader = CsvReader( header=None, @@ -201,10 +214,9 @@ def test_csv_reader_provenance_info(tmp_path): }, ) provenance_info = reader.provenance_info() - base_catalog_parameters = CatalogParameters( - output_path=tmp_path, catalog_name="empty" - ) - io.write_provenance_info(base_catalog_parameters, provenance_info) + catalog_base_dir = os.path.join(tmp_path, "test_catalog") + os.makedirs(catalog_base_dir) + io.write_provenance_info(catalog_base_dir, basic_catalog_info, provenance_info) def test_parquet_reader(parquet_shards_shard_44_0): @@ -226,14 +238,13 @@ def test_parquet_reader_chunked(parquet_shards_shard_44_0): assert total_chunks == 7 -def test_parquet_reader_provenance_info(tmp_path): +def test_parquet_reader_provenance_info(tmp_path, basic_catalog_info): """Test that we get some provenance info and it is parseable into JSON.""" reader = ParquetReader(chunksize=1) provenance_info = reader.provenance_info() - base_catalog_parameters = CatalogParameters( - output_path=tmp_path, catalog_name="empty" - ) - io.write_provenance_info(base_catalog_parameters, provenance_info) + catalog_base_dir = os.path.join(tmp_path, "test_catalog") + os.makedirs(catalog_base_dir) + io.write_provenance_info(catalog_base_dir, basic_catalog_info, provenance_info) def test_read_fits(formats_fits): @@ -267,11 +278,10 @@ def test_read_fits_columns(formats_fits): assert list(frame.columns) == ["id", "ra", "dec"] -def test_fits_reader_provenance_info(tmp_path): +def test_fits_reader_provenance_info(tmp_path, basic_catalog_info): """Test that we get some provenance info and it is parseable into JSON.""" reader = FitsReader() provenance_info = reader.provenance_info() - base_catalog_parameters = CatalogParameters( - output_path=tmp_path, catalog_name="empty" - ) - io.write_provenance_info(base_catalog_parameters, provenance_info) + catalog_base_dir = os.path.join(tmp_path, "test_catalog") + os.makedirs(catalog_base_dir) + io.write_provenance_info(catalog_base_dir, basic_catalog_info, provenance_info) diff --git a/tests/hipscat_import/catalog/test_run_import.py b/tests/hipscat_import/catalog/test_run_import.py index bd565996..0a30b558 100644 --- a/tests/hipscat_import/catalog/test_run_import.py +++ b/tests/hipscat_import/catalog/test_run_import.py @@ -4,6 +4,7 @@ import shutil import pytest +from hipscat.catalog.catalog import Catalog import hipscat_import.catalog.resume_files as rf import hipscat_import.catalog.run_import as runner @@ -29,7 +30,6 @@ def test_resume_dask_runner( small_sky_parts_dir, resume_dir, tmp_path, - assert_text_file_matches, assert_parquet_file_ids, ): """Test execution in the presence of some resume files.""" @@ -91,26 +91,13 @@ def test_resume_dask_runner( runner.run_with_client(args, dask_client) # Check that the catalog metadata file exists - expected_metadata_lines = [ - "{", - ' "catalog_name": "resume",', - ' "catalog_type": "object",', - ' "epoch": "J2000",', - ' "ra_kw": "ra",', - ' "dec_kw": "dec",', - ' "total_rows": 131', - "}", - ] - metadata_filename = os.path.join(args.catalog_path, "catalog_info.json") - assert_text_file_matches(expected_metadata_lines, metadata_filename) - - # Check that the partition info file exists - expected_partition_lines = [ - "Norder,Dir,Npix,num_rows", - "0,0,11,131", - ] - partition_filename = os.path.join(args.catalog_path, "partition_info.csv") - 
assert_text_file_matches(expected_partition_lines, partition_filename) + catalog = Catalog.read_from_hipscat(args.catalog_path) + assert catalog.on_disk + assert catalog.catalog_path == args.catalog_path + assert catalog.catalog_info.ra_column == "ra" + assert catalog.catalog_info.dec_column == "dec" + assert catalog.catalog_info.total_rows == 131 + assert len(catalog.get_pixels()) == 1 # Check that the catalog parquet file exists and contains correct object IDs output_file = os.path.join( @@ -146,8 +133,13 @@ def test_resume_dask_runner( runner.run_with_client(args, dask_client) - assert_text_file_matches(expected_metadata_lines, metadata_filename) - assert_text_file_matches(expected_partition_lines, partition_filename) + catalog = Catalog.read_from_hipscat(args.catalog_path) + assert catalog.on_disk + assert catalog.catalog_path == args.catalog_path + assert catalog.catalog_info.ra_column == "ra" + assert catalog.catalog_info.dec_column == "dec" + assert catalog.catalog_info.total_rows == 131 + assert len(catalog.get_pixels()) == 1 assert_parquet_file_ids(output_file, "id", expected_ids) @@ -156,7 +148,6 @@ def test_dask_runner( dask_client, small_sky_parts_dir, assert_parquet_file_ids, - assert_text_file_matches, tmp_path, ): """Test basic execution.""" @@ -173,26 +164,13 @@ def test_dask_runner( runner.run_with_client(args, dask_client) # Check that the catalog metadata file exists - expected_lines = [ - "{", - ' "catalog_name": "small_sky_object_catalog",', - ' "catalog_type": "object",', - ' "epoch": "J2000",', - ' "ra_kw": "ra",', - ' "dec_kw": "dec",', - ' "total_rows": 131', - "}", - ] - metadata_filename = os.path.join(args.catalog_path, "catalog_info.json") - assert_text_file_matches(expected_lines, metadata_filename) - - # Check that the partition info file exists - expected_lines = [ - "Norder,Dir,Npix,num_rows", - "0,0,11,131", - ] - metadata_filename = os.path.join(args.catalog_path, "partition_info.csv") - assert_text_file_matches(expected_lines, metadata_filename) + catalog = Catalog.read_from_hipscat(args.catalog_path) + assert catalog.on_disk + assert catalog.catalog_path == args.catalog_path + assert catalog.catalog_info.ra_column == "ra" + assert catalog.catalog_info.dec_column == "dec" + assert catalog.catalog_info.total_rows == 131 + assert len(catalog.get_pixels()) == 1 # Check that the catalog parquet file exists and contains correct object IDs output_file = os.path.join( diff --git a/tests/hipscat_import/catalog/test_run_round_trip.py b/tests/hipscat_import/catalog/test_run_round_trip.py index 174c09a9..8b050f01 100644 --- a/tests/hipscat_import/catalog/test_run_round_trip.py +++ b/tests/hipscat_import/catalog/test_run_round_trip.py @@ -11,6 +11,7 @@ import numpy.testing as npt import pandas as pd import pytest +from hipscat.catalog.catalog import Catalog import hipscat_import.catalog.run_import as runner from hipscat_import.catalog.arguments import ImportArguments @@ -21,7 +22,6 @@ def test_import_source_table( dask_client, small_sky_source_dir, - assert_text_file_matches, tmp_path, ): """Test basic execution, using a larger source file. 
@@ -47,39 +47,12 @@ def test_import_source_table( runner.run_with_client(args, dask_client) # Check that the catalog metadata file exists - expected_lines = [ - "{", - ' "catalog_name": "small_sky_source_catalog",', - ' "catalog_type": "source",', - ' "epoch": "J2000",', - ' "ra_kw": "source_ra",', - ' "dec_kw": "source_dec",', - ' "total_rows": 17161', - "}", - ] - metadata_filename = os.path.join(args.catalog_path, "catalog_info.json") - assert_text_file_matches(expected_lines, metadata_filename) - - # Check that the partition info file exists - expected_lines = [ - "Norder,Dir,Npix,num_rows", - "0,0,4,50", - "1,0,47,2395", - "2,0,176,385", - "2,0,177,1510", - "2,0,178,1634", - "2,0,179,1773", - "2,0,180,655", - "2,0,181,903", - "2,0,182,1246", - "2,0,183,1143", - "2,0,184,1390", - "2,0,185,2942", - "2,0,186,452", - "2,0,187,683", - ] - metadata_filename = os.path.join(args.catalog_path, "partition_info.csv") - assert_text_file_matches(expected_lines, metadata_filename) + catalog = Catalog.read_from_hipscat(args.catalog_path) + assert catalog.on_disk + assert catalog.catalog_path == args.catalog_path + assert catalog.catalog_info.ra_column == "source_ra" + assert catalog.catalog_info.dec_column == "source_dec" + assert len(catalog.get_pixels()) == 14 @pytest.mark.dask @@ -131,8 +104,8 @@ def test_import_preserve_index( ): """Test basic execution, with input with pandas metadata. - the input file is a parquet file with some pandas metadata. - this verifies that the parquet file at the end also has pandas - metadata, and the user's preferred id is retained as the index, + this verifies that the parquet file at the end also has pandas + metadata, and the user's preferred id is retained as the index, when requested. """ @@ -222,8 +195,8 @@ def test_import_multiindex( """Test basic execution, with input with pandas metadata - this is *similar* to the above test - the input file is a parquet file with a multi-level pandas index. - this verifies that the parquet file at the end also has pandas - metadata, and the user's preferred id is retained as the index, + this verifies that the parquet file at the end also has pandas + metadata, and the user's preferred id is retained as the index, when requested. """ @@ -310,7 +283,6 @@ def test_import_multiindex( def test_import_constant_healpix_order( dask_client, small_sky_parts_dir, - assert_text_file_matches, tmp_path, ): """Test basic execution. @@ -329,40 +301,12 @@ def test_import_constant_healpix_order( runner.run_with_client(args, dask_client) - # Check that the partition info file exists - all pixels at order 2! 
- expected_lines = [ - "Norder,Dir,Npix,num_rows", - "2,0,176,4", - "2,0,177,11", - "2,0,178,14", - "2,0,179,13", - "2,0,180,5", - "2,0,181,7", - "2,0,182,8", - "2,0,183,9", - "2,0,184,11", - "2,0,185,23", - "2,0,186,4", - "2,0,187,4", - "2,0,188,17", - "2,0,190,1", - ] - metadata_filename = os.path.join(args.catalog_path, "partition_info.csv") - assert_text_file_matches(expected_lines, metadata_filename) - # Check that the catalog metadata file exists - expected_lines = [ - "{", - ' "catalog_name": "small_sky_object_catalog",', - ' "catalog_type": "object",', - ' "epoch": "J2000",', - ' "ra_kw": "ra",', - ' "dec_kw": "dec",', - ' "total_rows": 131', - "}", - ] - metadata_filename = os.path.join(args.catalog_path, "catalog_info.json") - assert_text_file_matches(expected_lines, metadata_filename) + catalog = Catalog.read_from_hipscat(args.catalog_path) + assert catalog.on_disk + assert catalog.catalog_path == args.catalog_path + # Check that the partition info file exists - all pixels at order 2! + assert all(pixel.order == 2 for pixel in catalog.partition_info.get_healpix_pixels()) # Pick a parquet file and make sure it contains as many rows as we expect output_file = os.path.join( diff --git a/tests/hipscat_import/data/small_sky_object_catalog/catalog_info.json b/tests/hipscat_import/data/small_sky_object_catalog/catalog_info.json index 451af75e..59eaef9c 100644 --- a/tests/hipscat_import/data/small_sky_object_catalog/catalog_info.json +++ b/tests/hipscat_import/data/small_sky_object_catalog/catalog_info.json @@ -1,8 +1,8 @@ { "catalog_name": "small_sky_object_catalog", "catalog_type": "object", + "total_rows": 131, "epoch": "J2000", - "ra_kw": "ra", - "dec_kw": "dec", - "total_rows": 131 -} + "ra_column": "ra", + "dec_column": "dec" +} \ No newline at end of file diff --git a/tests/hipscat_import/data/small_sky_source_catalog/catalog_info.json b/tests/hipscat_import/data/small_sky_source_catalog/catalog_info.json index 7c819f1d..0491b5c5 100644 --- a/tests/hipscat_import/data/small_sky_source_catalog/catalog_info.json +++ b/tests/hipscat_import/data/small_sky_source_catalog/catalog_info.json @@ -1,8 +1,8 @@ { "catalog_name": "small_sky_source_catalog", "catalog_type": "source", + "total_rows": 17161, "epoch": "J2000", "ra_column": "source_ra", - "dec_column": "source_dec", - "total_rows": 17161 -} + "dec_column": "source_dec" +} \ No newline at end of file diff --git a/tests/hipscat_import/index/test_index_argument.py b/tests/hipscat_import/index/test_index_argument.py index c7032806..c1b3e125 100644 --- a/tests/hipscat_import/index/test_index_argument.py +++ b/tests/hipscat_import/index/test_index_argument.py @@ -128,7 +128,7 @@ def test_compute_partition_size(tmp_path, small_sky_object_catalog): ) -def test_to_catalog_parameters(small_sky_object_catalog, tmp_path): +def test_to_catalog_info(small_sky_object_catalog, tmp_path): """Verify creation of catalog parameters for index to be created.""" args = IndexArguments( input_catalog_path=small_sky_object_catalog, @@ -138,8 +138,9 @@ def test_to_catalog_parameters(small_sky_object_catalog, tmp_path): include_hipscat_index=True, include_order_pixel=True, ) - catalog_parameters = args.to_catalog_parameters() - assert catalog_parameters.catalog_name == args.output_catalog_name + catalog_info = args.to_catalog_info(total_rows=10) + assert catalog_info.catalog_name == args.output_catalog_name + assert catalog_info.total_rows == 10 def test_provenance_info(small_sky_object_catalog, tmp_path): diff --git 
a/tests/hipscat_import/index/test_run_index.py b/tests/hipscat_import/index/test_run_index.py index 0f84cd61..f9bf3d3e 100644 --- a/tests/hipscat_import/index/test_run_index.py +++ b/tests/hipscat_import/index/test_run_index.py @@ -5,6 +5,7 @@ import pyarrow as pa import pyarrow.parquet as pq import pytest +from hipscat.catalog.dataset.dataset import Dataset import hipscat_import.index.run_index as runner from hipscat_import.index.arguments import IndexArguments @@ -27,7 +28,6 @@ def test_bad_args(): def test_run_index( small_sky_object_catalog, tmp_path, - assert_text_file_matches, ): """Test appropriate metadata is written""" @@ -42,18 +42,9 @@ def test_run_index( runner.run(args) # Check that the catalog metadata file exists - expected_metadata_lines = [ - "{", - ' "catalog_name": "small_sky_object_index",', - ' "catalog_type": "index",', - ' "epoch": "J2000",', - ' "ra_kw": "ra",', - ' "dec_kw": "dec",', - ' "total_rows": 131', - "}", - ] - metadata_filename = os.path.join(args.catalog_path, "catalog_info.json") - assert_text_file_matches(expected_metadata_lines, metadata_filename) + catalog = Dataset.read_from_hipscat(args.catalog_path) + assert catalog.on_disk + assert catalog.catalog_path == args.catalog_path basic_index_parquet_schema = pa.schema( [ diff --git a/tests/hipscat_import/margin_cache/test_arguments_margin_cache.py b/tests/hipscat_import/margin_cache/test_arguments_margin_cache.py index 0d1537af..00bab144 100644 --- a/tests/hipscat_import/margin_cache/test_arguments_margin_cache.py +++ b/tests/hipscat_import/margin_cache/test_arguments_margin_cache.py @@ -7,6 +7,7 @@ # pylint: disable=protected-access + def test_empty_required(tmp_path): """*Most* required arguments are provided.""" ## Input catalog path is missing @@ -17,6 +18,7 @@ def test_empty_required(tmp_path): output_catalog_name="catalog_cache", ) + def test_margin_order_dynamic(small_sky_source_catalog, tmp_path): """Ensure we can dynamically set the margin_order""" args = MarginCacheArguments( @@ -28,6 +30,7 @@ def test_margin_order_dynamic(small_sky_source_catalog, tmp_path): assert args.margin_order == 3 + def test_margin_order_static(small_sky_source_catalog, tmp_path): """Ensure we can manually set the margin_order""" args = MarginCacheArguments( @@ -35,11 +38,12 @@ def test_margin_order_static(small_sky_source_catalog, tmp_path): input_catalog_path=small_sky_source_catalog, output_path=tmp_path, output_catalog_name="catalog_cache", - margin_order=4 + margin_order=4, ) assert args.margin_order == 4 + def test_margin_order_invalid(small_sky_source_catalog, tmp_path): """Ensure we raise an exception when margin_order is invalid""" with pytest.raises(ValueError, match="margin_order"): @@ -48,9 +52,10 @@ def test_margin_order_invalid(small_sky_source_catalog, tmp_path): input_catalog_path=small_sky_source_catalog, output_path=tmp_path, output_catalog_name="catalog_cache", - margin_order=2 + margin_order=2, ) + def test_margin_threshold_warns(small_sky_source_catalog, tmp_path): """Ensure we give a warning when margin_threshold is greater than margin_order resolution""" @@ -60,5 +65,33 @@ def test_margin_threshold_warns(small_sky_source_catalog, tmp_path): input_catalog_path=small_sky_source_catalog, output_path=tmp_path, output_catalog_name="catalog_cache", - margin_order=16 + margin_order=16, ) + + +def test_to_catalog_info(small_sky_source_catalog, tmp_path): + """Verify creation of catalog info for margin cache to be created.""" + args = MarginCacheArguments( + margin_threshold=5.0, + 
input_catalog_path=small_sky_source_catalog, + output_path=tmp_path, + output_catalog_name="catalog_cache", + margin_order=4, + ) + catalog_info = args.to_catalog_info(total_rows=10) + assert catalog_info.catalog_name == args.output_catalog_name + assert catalog_info.total_rows == 10 + + +def test_provenance_info(small_sky_source_catalog, tmp_path): + """Verify that provenance info includes margin-cache-specific fields.""" + args = MarginCacheArguments( + margin_threshold=5.0, + input_catalog_path=small_sky_source_catalog, + output_path=tmp_path, + output_catalog_name="catalog_cache", + margin_order=4, + ) + + runtime_args = args.provenance_info()["runtime_args"] + assert "margin_threshold" in runtime_args
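
For reviewers: the change above migrates every pipeline from the single `CatalogParameters` class to catalog-type-specific `*CatalogInfo` dataclasses, with `total_rows` supplied at construction time instead of being patched onto the object afterwards. Below is a minimal sketch of the new write path; it assumes a hipscat version that exposes `CatalogInfo` under `hipscat.catalog.catalog` and a `write_metadata` module importable from `hipscat.io` (consistent with the imports in this diff), and the output directory and `tool_args` values are placeholders, not project defaults.

```python
import os

from hipscat.catalog.catalog import CatalogInfo
from hipscat.io import write_metadata

catalog_base_dir = "/tmp/small_sky_object_catalog"  # hypothetical output location
os.makedirs(catalog_base_dir, exist_ok=True)

# total_rows is now a required part of the info object itself.
catalog_info = CatalogInfo(
    catalog_name="small_sky_object_catalog",
    catalog_type="object",
    total_rows=131,
    epoch="J2000",
    ra_column="ra",
    dec_column="dec",
)

# Keyword-only call style, matching the updated runners in this diff.
write_metadata.write_provenance_info(
    catalog_base_dir=catalog_base_dir,
    dataset_info=catalog_info,
    tool_args={"tool_name": "hipscat_import"},  # stand-in for args.provenance_info()
)
write_metadata.write_catalog_info(
    catalog_base_dir=catalog_base_dir, dataset_info=catalog_info
)

# The full pipeline also writes partition info and parquet _metadata (see
# run_import.py above) before the catalog can be read back with
# Catalog.read_from_hipscat(catalog_base_dir), which is how the updated
# tests now validate output instead of string-matching metadata files.
```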
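The reformatted resolution check in `margin_cache_arguments.py` is also easy to sanity-check numerically. Here is a small self-contained reproduction of the warning condition; the helper name is made up, but the healpy calls are exactly the ones the diff uses, and `margin_threshold` is in arcseconds, hence the `* 60.0` applied to an arcminute resolution:

```python
import healpy as hp

def margin_resolution_arcsec(margin_order: int) -> float:
    """Approximate width of a margin pixel at the given order, in arcseconds."""
    nside = hp.order2nside(margin_order)  # 2 ** margin_order
    return hp.nside2resol(nside, arcmin=True) * 60.0

# The arguments class warns when the pixel resolution drops below
# margin_threshold, since margin pixels narrower than the threshold
# can lose rows from the cache.
print(margin_resolution_arcsec(4))   # ~13,000 arcsec (~3.7 deg): fine for a 5 arcsec threshold
print(margin_resolution_arcsec(16))  # ~3.2 arcsec: below a 5 arcsec threshold, so it warns
```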
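Finally, each pipeline now records type-specific metadata fields rather than the generic parameter set; the association variant, for instance, carries the primary/join columns and catalog paths. A sketch of the dataclass construction, with placeholder values loosely modeled on the small_sky fixtures (the paths and column names here are illustrative, not the fixture values):

```python
from hipscat.catalog.association_catalog.association_catalog import (
    AssociationCatalogInfo,
)

# Hypothetical stand-ins for AssociationArguments attributes.
info = AssociationCatalogInfo(
    catalog_name="small_sky_association",
    catalog_type="association",
    total_rows=17161,
    primary_column="id",
    primary_catalog="/path/to/small_sky_object_catalog",
    join_column="object_id",
    join_catalog="/path/to/small_sky_source_catalog",
)
assert info.catalog_name == "small_sky_association"
```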