diff --git a/.github/workflows/testing-and-coverage.yml b/.github/workflows/testing-and-coverage.yml index 4507157b..b497a967 100644 --- a/.github/workflows/testing-and-coverage.yml +++ b/.github/workflows/testing-and-coverage.yml @@ -27,8 +27,7 @@ jobs: run: | sudo apt-get update python -m pip install --upgrade pip - pip install . - pip install .[dev] + pip install -e .[dev] if [ -f requirements.txt ]; then pip install -r requirements.txt; fi - name: Run unit tests with pytest run: | diff --git a/docs/guide/command_line.rst b/docs/guide/command_line.rst deleted file mode 100644 index 34c36c40..00000000 --- a/docs/guide/command_line.rst +++ /dev/null @@ -1,7 +0,0 @@ -Command Line Arguments -=============================================================================== - -TODO - -Arguments -------------------------------------------------------------------------------- \ No newline at end of file diff --git a/docs/guide/overview.rst b/docs/guide/overview.rst index 1b9bdc6d..14fd472f 100644 --- a/docs/guide/overview.rst +++ b/docs/guide/overview.rst @@ -11,5 +11,4 @@ Installation Other Topics ------------------------------------------------------------------------------- -* :doc:`command_line` * :doc:`resume` diff --git a/docs/index.rst b/docs/index.rst index d0ecaaaf..f56c5dae 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -8,7 +8,6 @@ Utility for ingesting large survey data into HiPSCat structure. :caption: Importing Catalogs guide/overview - guide/command_line guide/resume Notebooks diff --git a/docs/notebooks/intro_notebook.ipynb b/docs/notebooks/intro_notebook.ipynb index 2e7779f5..71086d5e 100644 --- a/docs/notebooks/intro_notebook.ipynb +++ b/docs/notebooks/intro_notebook.ipynb @@ -65,9 +65,9 @@ " # Calculate the cosine of each value of X\n", " z = np.cos(x)\n", " # Plot the sine wave in blue, using degrees rather than radians on the X axis\n", - " pl.plot(xdeg, y, color='blue', label='Sine wave')\n", + " pl.plot(xdeg, y, color=\"blue\", label=\"Sine wave\")\n", " # Plot the cos wave in green, using degrees rather than radians on the X axis\n", - " pl.plot(xdeg, z, color='green', label='Cosine wave')\n", + " pl.plot(xdeg, z, color=\"green\", label=\"Cosine wave\")\n", " pl.xlabel(\"Degrees\")\n", " # More sensible X axis values\n", " pl.xticks(np.arange(0, 361, 45))\n", diff --git a/pyproject.toml b/pyproject.toml index 9fbb62d9..64e6c3ac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,12 +62,15 @@ write_to = "src/hipscat_import/_version.py" [tool.setuptools.package-data] hipscat_import = ["py.typed"] -[project.scripts] -hipscat-import = "hipscat_import.control:main" -hc = "hipscat_import.control:main" [tool.pytest.ini_options] timeout = 1 markers = [ "dask: mark tests as having a dask client runtime dependency", +] + +[tool.coverage.report] +omit = [ + "src/hipscat_import/_version.py", # auto-generated + "src/hipscat_import/pipeline.py", # too annoying to test ] \ No newline at end of file diff --git a/src/hipscat_import/__init__.py b/src/hipscat_import/__init__.py index 409818be..baba0a94 100644 --- a/src/hipscat_import/__init__.py +++ b/src/hipscat_import/__init__.py @@ -1,4 +1,3 @@ """All modules for hipscat-import package""" -from .control import main from .runtime_arguments import RuntimeArguments diff --git a/src/hipscat_import/__main__.py b/src/hipscat_import/__main__.py deleted file mode 100644 index a4585e85..00000000 --- a/src/hipscat_import/__main__.py +++ /dev/null @@ -1,7 +0,0 @@ -"""Main method to enable command line execution""" - - -from hipscat_import.control import main - -if __name__ == "__main__": - main() diff --git a/src/hipscat_import/association/arguments.py b/src/hipscat_import/association/arguments.py index 28b8f852..e1b717e2 100644 --- a/src/hipscat_import/association/arguments.py +++ b/src/hipscat_import/association/arguments.py @@ -2,7 +2,9 @@ from dataclasses import dataclass -from hipscat.catalog import CatalogParameters +from hipscat.catalog.association_catalog.association_catalog import ( + AssociationCatalogInfo, +) from hipscat_import.runtime_arguments import RuntimeArguments @@ -51,19 +53,20 @@ def _check_arguments(self): if self.compute_partition_size < 100_000: raise ValueError("compute_partition_size must be at least 100_000") - def to_catalog_parameters(self) -> CatalogParameters: - """Convert importing arguments into hipscat catalog parameters. - - Returns: - CatalogParameters for catalog being created. - """ - return CatalogParameters( - catalog_name=self.output_catalog_name, - catalog_type="association", - output_path=self.output_path, - ) + def to_catalog_info(self, total_rows) -> AssociationCatalogInfo: + """Catalog-type-specific dataset info.""" + info = { + "catalog_name": self.output_catalog_name, + "catalog_type": "association", + "total_rows": total_rows, + "primary_column": self.primary_id_column, + "primary_catalog": str(self.primary_input_catalog_path), + "join_column": self.join_id_column, + "join_catalog": str(self.join_input_catalog_path), + } + return AssociationCatalogInfo(**info) - def additional_runtime_provenance_info(self): + def additional_runtime_provenance_info(self) -> dict: return { "primary_input_catalog_path": str(self.primary_input_catalog_path), "primary_id_column": self.primary_id_column, diff --git a/src/hipscat_import/association/run_association.py b/src/hipscat_import/association/run_association.py index ba505967..162fd013 100644 --- a/src/hipscat_import/association/run_association.py +++ b/src/hipscat_import/association/run_association.py @@ -12,17 +12,13 @@ from hipscat_import.association.map_reduce import map_association, reduce_association -def _validate_args(args): +def run(args): + """Run the association pipeline""" if not args: raise TypeError("args is required and should be type AssociationArguments") if not isinstance(args, AssociationArguments): raise TypeError("args must be type AssociationArguments") - -def run(args): - """Run the association pipeline""" - _validate_args(args) - with tqdm(total=1, desc="Mapping ", disable=not args.progress_bar) as step_progress: map_association(args) step_progress.update(1) @@ -40,11 +36,17 @@ def run(args): ) as step_progress: # pylint: disable=duplicate-code # Very similar to /index/run_index.py - catalog_params = args.to_catalog_parameters() - catalog_params.total_rows = int(rows_written) - write_metadata.write_provenance_info(catalog_params, args.provenance_info()) + catalog_info = args.to_catalog_info(int(rows_written)) + write_metadata.write_provenance_info( + catalog_base_dir=args.catalog_path, + dataset_info=catalog_info, + tool_args=args.provenance_info(), + ) step_progress.update(1) - write_metadata.write_catalog_info(catalog_params) + catalog_info = args.to_catalog_info(total_rows=int(rows_written)) + write_metadata.write_catalog_info( + dataset_info=catalog_info, catalog_base_dir=args.catalog_path + ) step_progress.update(1) write_metadata.write_parquet_metadata(args.catalog_path) step_progress.update(1) diff --git a/src/hipscat_import/catalog/__init__.py b/src/hipscat_import/catalog/__init__.py index 6015c90c..fe9a1e61 100644 --- a/src/hipscat_import/catalog/__init__.py +++ b/src/hipscat_import/catalog/__init__.py @@ -1,14 +1,26 @@ """All modules for importing new catalogs.""" from .arguments import ImportArguments -from .command_line_arguments import parse_command_line -from .file_readers import (CsvReader, FitsReader, InputReader, ParquetReader, - get_file_reader) +from .file_readers import ( + CsvReader, + FitsReader, + InputReader, + ParquetReader, + get_file_reader, +) from .map_reduce import map_to_pixels, reduce_pixel_shards, split_pixels -from .resume_files import (clean_resume_files, is_mapping_done, - is_reducing_done, read_histogram, read_mapping_keys, - read_reducing_keys, set_mapping_done, - set_reducing_done, write_histogram, - write_mapping_done_key, write_mapping_start_key, - write_reducing_key) -from .run_import import run, run_with_client +from .resume_files import ( + clean_resume_files, + is_mapping_done, + is_reducing_done, + read_histogram, + read_mapping_keys, + read_reducing_keys, + set_mapping_done, + set_reducing_done, + write_histogram, + write_mapping_done_key, + write_mapping_start_key, + write_reducing_key, +) +from .run_import import run diff --git a/src/hipscat_import/catalog/arguments.py b/src/hipscat_import/catalog/arguments.py index 7fa98135..6d83c9f8 100644 --- a/src/hipscat_import/catalog/arguments.py +++ b/src/hipscat_import/catalog/arguments.py @@ -6,7 +6,7 @@ from typing import Callable, List import pandas as pd -from hipscat.catalog import CatalogParameters +from hipscat.catalog.catalog import CatalogInfo from hipscat.io import FilePointer, file_io from hipscat.pixel_math import hipscat_id @@ -96,8 +96,8 @@ def _check_arguments(self): check_healpix_order_range( self.highest_healpix_order, "highest_healpix_order" ) - if not 100 <= self.pixel_threshold <= 10_000_000: - raise ValueError("pixel_threshold should be between 100 and 10,000,000") + if not 100 <= self.pixel_threshold <= 1_000_000_000: + raise ValueError("pixel_threshold should be between 100 and 1,000,000,000") self.mapping_healpix_order = self.highest_healpix_order if self.catalog_type not in ("source", "object"): @@ -140,21 +140,19 @@ def _check_arguments(self): if not self.filter_function: self.filter_function = passthrough_filter_function - def to_catalog_parameters(self) -> CatalogParameters: - """Convert importing arguments into hipscat catalog parameters. - Returns: - CatalogParameters for catalog being created. - """ - return CatalogParameters( - catalog_name=self.output_catalog_name, - catalog_type=self.catalog_type, - output_path=self.output_path, - epoch=self.epoch, - ra_column=self.ra_column, - dec_column=self.dec_column, - ) + def to_catalog_info(self, total_rows) -> CatalogInfo: + """Catalog-type-specific dataset info.""" + info = { + "catalog_name": self.output_catalog_name, + "catalog_type": self.catalog_type, + "total_rows": total_rows, + "epoch": self.epoch, + "ra_column": self.ra_column, + "dec_column": self.dec_column, + } + return CatalogInfo(**info) - def additional_runtime_provenance_info(self): + def additional_runtime_provenance_info(self) -> dict: return { "catalog_name": self.output_catalog_name, "epoch": self.epoch, @@ -171,7 +169,9 @@ def additional_runtime_provenance_info(self): "pixel_threshold": self.pixel_threshold, "mapping_healpix_order": self.mapping_healpix_order, "debug_stats_only": self.debug_stats_only, - "file_reader_info": self.file_reader.provenance_info(), + "file_reader_info": self.file_reader.provenance_info() + if self.file_reader is not None + else {}, } @@ -180,6 +180,7 @@ def check_healpix_order_range( ): """Helper method to heck if the `order` is within the range determined by the `lower_bound` and `upper_bound`, inclusive. + Args: order (int): healpix order to check field_name (str): field name to use in the error message diff --git a/src/hipscat_import/catalog/command_line_arguments.py b/src/hipscat_import/catalog/command_line_arguments.py deleted file mode 100644 index 226e7096..00000000 --- a/src/hipscat_import/catalog/command_line_arguments.py +++ /dev/null @@ -1,246 +0,0 @@ -"""Parse import arguments from command line""" - -import argparse - -from hipscat_import.catalog.arguments import ImportArguments -from hipscat_import.catalog.file_readers import get_file_reader - - -def parse_command_line(cl_args): - """Parse arguments from the command line""" - - parser = argparse.ArgumentParser( - prog="LSD2 Partitioner", - description="Instantiate a partitioned catalog from unpartitioned sources", - ) - - # =========== INPUT ARGUMENTS =========== - group = parser.add_argument_group("INPUT") - group.add_argument( - "-c", - "--catalog_name", - help="short name for the catalog that will be used for the output directory", - default=None, - type=str, - ) - group.add_argument( - "-i", - "--input_path", - help="path prefix for unpartitioned input files", - default=None, - type=str, - ) - group.add_argument( - "-fmt", - "--input_format", - help="file format for unpartitioned input files", - default="parquet", - type=str, - ) - group.add_argument( - "--input_file_list", - help="explicit list of input files, comma-separated", - default="", - type=str, - ) - - # =========== READER ARGUMENTS =========== - group = parser.add_argument_group("READER") - group.add_argument( - "--schema_file", - help="parquet file that contains field names and types to be used when reading a CSV file", - default=None, - type=str, - ) - group.add_argument( - "--header_rows", - help="number of rows of header in a CSV. if 0, there are only data rows", - default=1, - type=int, - ) - group.add_argument( - "--column_names", - help="comma-separated list of names of columns. " - "used in the absence of a header row or to rename columns", - default=None, - type=str, - ) - group.add_argument( - "--separator", - help="field delimiter in text or CSV file", - default=",", - type=str, - ) - group.add_argument( - "--chunksize", - help="number of input rows to process in a chunk. recommend using" - " a smaller chunk size for input with wider rows", - default=500_000, - type=int, - ) - # =========== INPUT COLUMNS =========== - group = parser.add_argument_group( - "INPUT COLUMNS", - """Column names in the input source that - correspond to spatial attributes used in partitioning""", - ) - group.add_argument( - "-ra", - "--ra_column", - help="column name for the ra (rate of ascension)", - default="ra", - type=str, - ) - group.add_argument( - "-dec", - "--dec_column", - help="column name for the dec (declination)", - default="dec", - type=str, - ) - group.add_argument( - "-id", - "--id_column", - help="column name for the object id", - default="id", - type=str, - ) - # =========== OUTPUT ARGUMENTS =========== - group = parser.add_argument_group("OUTPUT") - group.add_argument( - "-o", - "--output_path", - help="path prefix for partitioned output and metadata files", - default=None, - type=str, - ) - group.add_argument( - "--add_hipscat_index", - help="Option to generate the _hipscat_index column " - "a spatially aware index for read and join optimization", - action="store_true", - ) - group.add_argument( - "--overwrite", - help="if set, any existing catalog data will be overwritten", - action="store_true", - ) - group.add_argument( - "--no_overwrite", - help="if set, the pipeline will exit if existing output is found", - dest="overwrite", - action="store_false", - ) - group.add_argument( - "--resume", - help="if set, the pipeline will try to resume from a previous failed pipeline progress", - action="store_true", - ) - group.add_argument( - "--no_resume", - help="if set, the pipeline will exit if existing intermediate files are found", - dest="resume", - action="store_false", - ) - - # =========== STATS ARGUMENTS =========== - group = parser.add_argument_group("STATS") - group.add_argument( - "-ho", - "--highest_healpix_order", - help="the most dense healpix order (7-10 is a good range for this)", - default=10, - type=int, - ) - group.add_argument( - "-pt", - "--pixel_threshold", - help="maximum objects allowed in a single pixel", - default=1_000_000, - type=int, - ) - group.add_argument( - "--debug_stats_only", - help="""DEBUGGING FLAG - - if set, the pipeline will only fetch statistics about the origin data - and will not generate partitioned output""", - action="store_true", - ) - group.add_argument( - "--no_debug_stats_only", - help="DEBUGGING FLAG - if set, the pipeline will generate partitioned output", - dest="debug_stats_only", - action="store_false", - ) - # =========== EXECUTION ARGUMENTS =========== - group = parser.add_argument_group("EXECUTION") - group.add_argument( - "--progress_bar", - help="should a progress bar be displayed?", - default=True, - action="store_true", - ) - group.add_argument( - "--no_progress_bar", - help="should a progress bar be displayed?", - dest="progress_bar", - action="store_false", - ) - group.add_argument( - "--tmp_dir", - help="directory for storing temporary parquet files", - default=None, - type=str, - ) - group.add_argument( - "-dt", - "--dask_tmp", - help="directory for storing temporary files generated by dask engine", - default=None, - type=str, - ) - group.add_argument( - "--dask_n_workers", - help="the number of dask workers available", - default=1, - type=int, - ) - group.add_argument( - "--dask_threads_per_worker", - help="the number of threads per dask worker", - default=1, - type=int, - ) - - args = parser.parse_args(cl_args) - - return ImportArguments( - output_catalog_name=args.catalog_name, - input_path=args.input_path, - input_format=args.input_format, - input_file_list=( - args.input_file_list.split(",") if args.input_file_list else None - ), - ra_column=args.ra_column, - dec_column=args.dec_column, - id_column=args.id_column, - add_hipscat_index=args.add_hipscat_index, - output_path=args.output_path, - overwrite=args.overwrite, - highest_healpix_order=args.highest_healpix_order, - pixel_threshold=args.pixel_threshold, - debug_stats_only=args.debug_stats_only, - file_reader=get_file_reader( - args.input_format, - chunksize=args.chunksize, - header=args.header_rows if args.header_rows != 0 else None, - schema_file=args.schema_file, - column_names=(args.column_names.split(",") if args.column_names else None), - separator=args.separator, - ), - tmp_dir=args.tmp_dir, - progress_bar=args.progress_bar, - dask_tmp=args.dask_tmp, - dask_n_workers=args.dask_n_workers, - dask_threads_per_worker=args.dask_threads_per_worker, - ) diff --git a/src/hipscat_import/catalog/map_reduce.py b/src/hipscat_import/catalog/map_reduce.py index cc75dc70..34fa30a5 100644 --- a/src/hipscat_import/catalog/map_reduce.py +++ b/src/hipscat_import/catalog/map_reduce.py @@ -98,7 +98,7 @@ def map_to_pixels( Args: input_file (FilePointer): file to read for catalog data. - file_reader (hipscat_import.catalog.file_readers.InputReader): instance of input + file_reader (hipscat_import.catalog.file_readers.InputReader): instance of input reader that specifies arguments necessary for reading from the input file. shard_suffix (str): unique counter for this input file, used when creating intermediate files @@ -137,7 +137,7 @@ def split_pixels( Args: input_file (FilePointer): file to read for catalog data. - file_reader (hipscat_import.catalog.file_readers.InputReader): instance + file_reader (hipscat_import.catalog.file_readers.InputReader): instance of input reader that specifies arguments necessary for reading from the input file. shard_suffix (str): unique counter for this input file, used when creating intermediate files diff --git a/src/hipscat_import/catalog/run_import.py b/src/hipscat_import/catalog/run_import.py index 9b32416c..7155c2e4 100644 --- a/src/hipscat_import/catalog/run_import.py +++ b/src/hipscat_import/catalog/run_import.py @@ -7,7 +7,7 @@ import hipscat.io.write_metadata as io import numpy as np -from dask.distributed import Client, as_completed +from dask.distributed import as_completed from hipscat import pixel_math from tqdm import tqdm @@ -50,13 +50,14 @@ def _map_pixels(args, client): total=len(futures), disable=(not args.progress_bar), ): - if future.status == "error": # pragma: no cover + if future.status == "error": # pragma: no cover some_error = True + continue raw_histogram = np.add(raw_histogram, result) resume.write_mapping_start_key(args.tmp_path, future.key) resume.write_histogram(args.tmp_path, raw_histogram) resume.write_mapping_done_key(args.tmp_path, future.key) - if some_error: # pragma: no cover + if some_error: # pragma: no cover raise RuntimeError("Some mapping stages failed. See logs for details.") resume.set_mapping_done(args.tmp_path) return raw_histogram @@ -98,10 +99,11 @@ def _split_pixels(args, alignment_future, client): total=len(futures), disable=(not args.progress_bar), ): - if future.status == "error": # pragma: no cover + if future.status == "error": # pragma: no cover some_error = True + continue resume.write_splitting_done_key(args.tmp_path, future.key) - if some_error: # pragma: no cover + if some_error: # pragma: no cover raise RuntimeError("Some splitting stages failed. See logs for details.") resume.set_splitting_done(args.tmp_path) @@ -143,38 +145,21 @@ def _reduce_pixels(args, destination_pixel_map, client): total=len(futures), disable=(not args.progress_bar), ): - if future.status == "error": # pragma: no cover + if future.status == "error": # pragma: no cover some_error = True + continue resume.write_reducing_key(args.tmp_path, future.key) - if some_error: # pragma: no cover + if some_error: # pragma: no cover raise RuntimeError("Some reducing stages failed. See logs for details.") resume.set_reducing_done(args.tmp_path) -def _validate_args(args): +def run(args, client): + """Run catalog creation pipeline.""" if not args: raise ValueError("args is required and should be type ImportArguments") if not isinstance(args, ImportArguments): raise ValueError("args must be type ImportArguments") - - -def run(args): - """Importer that creates a dask client from the arguments""" - _validate_args(args) - - # pylint: disable=duplicate-code - with Client( - local_directory=args.dask_tmp, - n_workers=args.dask_n_workers, - threads_per_worker=args.dask_threads_per_worker, - ) as client: # pragma: no cover - run_with_client(args, client) - # pylint: enable=duplicate-code - - -def run_with_client(args, client): - """Importer, where the client context may out-live the runner""" - _validate_args(args) raw_histogram = _map_pixels(args, client) with tqdm( @@ -215,12 +200,17 @@ def run_with_client(args, client): with tqdm( total=6, desc="Finishing", disable=not args.progress_bar ) as step_progress: - catalog_parameters = args.to_catalog_parameters() - catalog_parameters.total_rows = int(raw_histogram.sum()) - io.write_provenance_info(catalog_parameters, args.provenance_info()) + catalog_info = args.to_catalog_info(int(raw_histogram.sum())) + io.write_provenance_info( + catalog_base_dir=args.catalog_path, + dataset_info=catalog_info, + tool_args=args.provenance_info(), + ) step_progress.update(1) - io.write_catalog_info(catalog_parameters) + io.write_catalog_info( + catalog_base_dir=args.catalog_path, dataset_info=catalog_info + ) step_progress.update(1) if not args.debug_stats_only: io.write_parquet_metadata(args.catalog_path) @@ -228,7 +218,8 @@ def run_with_client(args, client): io.write_fits_map(args.catalog_path, raw_histogram) step_progress.update(1) io.write_partition_info( - catalog_parameters, destination_healpix_pixel_map=destination_pixel_map + catalog_base_dir=args.catalog_path, + destination_healpix_pixel_map=destination_pixel_map, ) step_progress.update(1) resume.clean_resume_files(args.tmp_path) diff --git a/src/hipscat_import/control.py b/src/hipscat_import/control.py deleted file mode 100644 index 5a46e808..00000000 --- a/src/hipscat_import/control.py +++ /dev/null @@ -1,10 +0,0 @@ -"""Flow control and scripting entry points.""" -import sys - -import hipscat_import.catalog.run_import as runner -from hipscat_import.catalog.command_line_arguments import parse_command_line - - -def main(): - """Wrapper of main for setuptools.""" - runner.run(parse_command_line(sys.argv[1:])) diff --git a/src/hipscat_import/index/arguments.py b/src/hipscat_import/index/arguments.py index d7ba2856..88fd957c 100644 --- a/src/hipscat_import/index/arguments.py +++ b/src/hipscat_import/index/arguments.py @@ -3,7 +3,8 @@ from dataclasses import dataclass, field from typing import List, Optional -from hipscat.catalog import Catalog, CatalogParameters +from hipscat.catalog import Catalog +from hipscat.catalog.index.index_catalog_info import IndexCatalogInfo from hipscat_import.runtime_arguments import RuntimeArguments @@ -46,19 +47,19 @@ def _check_arguments(self): if self.compute_partition_size < 100_000: raise ValueError("compute_partition_size must be at least 100_000") - def to_catalog_parameters(self) -> CatalogParameters: - """Convert importing arguments into hipscat catalog parameters. - - Returns: - CatalogParameters for catalog being created. - """ - return CatalogParameters( - catalog_name=self.output_catalog_name, - catalog_type="index", - output_path=self.output_path, - ) + def to_catalog_info(self, total_rows) -> IndexCatalogInfo: + """Catalog-type-specific dataset info.""" + info = { + "catalog_name": self.output_catalog_name, + "total_rows": total_rows, + "catalog_type": "index", + "primary_catalog": str(self.input_catalog_path), + "indexing_column": self.indexing_column, + "extra_columns": self.extra_columns, + } + return IndexCatalogInfo(**info) - def additional_runtime_provenance_info(self): + def additional_runtime_provenance_info(self) -> dict: return { "input_catalog_path": str(self.input_catalog_path), "indexing_column": self.indexing_column, diff --git a/src/hipscat_import/index/run_index.py b/src/hipscat_import/index/run_index.py index d30b6f47..e1912538 100644 --- a/src/hipscat_import/index/run_index.py +++ b/src/hipscat_import/index/run_index.py @@ -7,16 +7,12 @@ from hipscat_import.index.arguments import IndexArguments -def _validate_args(args): +def run(args): + """Run index creation pipeline.""" if not args: raise TypeError("args is required and should be type IndexArguments") if not isinstance(args, IndexArguments): raise TypeError("args must be type IndexArguments") - - -def run(args): - """Importer, where the client context may out-live the runner""" - _validate_args(args) rows_written = mr.create_index(args) # All done - write out the metadata @@ -25,11 +21,16 @@ def run(args): ) as step_progress: # pylint: disable=duplicate-code # Very similar to /association/run_association.py - catalog_params = args.to_catalog_parameters() - catalog_params.total_rows = int(rows_written) - write_metadata.write_provenance_info(catalog_params, args.provenance_info()) + catalog_info = args.to_catalog_info(int(rows_written)) + write_metadata.write_provenance_info( + catalog_base_dir=args.catalog_path, + dataset_info=catalog_info, + tool_args=args.provenance_info(), + ) step_progress.update(1) - write_metadata.write_catalog_info(catalog_params) + write_metadata.write_catalog_info( + catalog_base_dir=args.catalog_path, dataset_info=catalog_info + ) step_progress.update(1) file_io.remove_directory(args.tmp_path, ignore_errors=True) step_progress.update(1) diff --git a/src/hipscat_import/margin_cache/__init__.py b/src/hipscat_import/margin_cache/__init__.py index 940d08b5..93f63fe3 100644 --- a/src/hipscat_import/margin_cache/__init__.py +++ b/src/hipscat_import/margin_cache/__init__.py @@ -1,4 +1,3 @@ """All modules for generating margin caches.""" -from .margin_cache import (generate_margin_cache, - generate_margin_cache_with_client) +from .margin_cache import generate_margin_cache from .margin_cache_arguments import MarginCacheArguments diff --git a/src/hipscat_import/margin_cache/margin_cache.py b/src/hipscat_import/margin_cache/margin_cache.py index ab67345f..ddb2d0a4 100644 --- a/src/hipscat_import/margin_cache/margin_cache.py +++ b/src/hipscat_import/margin_cache/margin_cache.py @@ -1,5 +1,5 @@ import pandas as pd -from dask.distributed import Client, as_completed +from dask.distributed import as_completed from hipscat import pixel_math from hipscat.io import file_io, paths from tqdm import tqdm @@ -8,6 +8,7 @@ # pylint: disable=too-many-locals,too-many-arguments + def _find_partition_margin_pixel_pairs(stats, margin_order): """Creates a DataFrame filled with many-to-many connections between the catalog partition pixels and the margin pixels at `margin_order`. @@ -31,30 +32,28 @@ def _find_partition_margin_pixel_pairs(stats, margin_order): margin_pairs_df = pd.DataFrame( zip(norders, part_pix, margin_pix), - columns=["partition_order", "partition_pixel", "margin_pixel"] + columns=["partition_order", "partition_pixel", "margin_pixel"], ) return margin_pairs_df + def _create_margin_directory(stats, output_path): """Creates directories for all the catalog partitions.""" for healpixel in stats: order = healpixel.order pix = healpixel.pixel - destination_dir = paths.pixel_directory( - output_path, order, pix - ) + destination_dir = paths.pixel_directory(output_path, order, pix) file_io.make_directory(destination_dir, exist_ok=True) + def _map_to_margin_shards(client, args, partition_pixels, margin_pairs): """Create all the jobs for mapping partition files into the margin cache.""" futures = [] mp_future = client.scatter(margin_pairs, broadcast=True) for pix in partition_pixels: partition_file = paths.pixel_catalog_file( - args.input_catalog_path, - pix.order, - pix.pixel + args.input_catalog_path, pix.order, pix.pixel ) futures.append( client.submit( @@ -97,50 +96,26 @@ def _reduce_margin_shards(client, args, partition_pixels): ): ... -def generate_margin_cache(args): +def generate_margin_cache(args, client): """Generate a margin cache for a given input catalog. The input catalog must be in hipscat format. - This method will handle the creation of the `dask.distributed` client - based on the `dask_tmp`, `dask_n_workers`, and `dask_threads_per_worker` - values of the `MarginCacheArguments` object. Args: args (MarginCacheArguments): A valid `MarginCacheArguments` object. - """ - # pylint: disable=duplicate-code - with Client( - local_directory=args.dask_tmp, - n_workers=args.dask_n_workers, - threads_per_worker=args.dask_threads_per_worker, - ) as client: # pragma: no cover - generate_margin_cache_with_client( - client, - args - ) - # pylint: enable=duplicate-code - -def generate_margin_cache_with_client(client, args): - """Generate a margin cache for a given input catalog. - The input catalog must be in hipscat format. - Args: client (dask.distributed.Client): A dask distributed client object. - args (MarginCacheArguments): A valid `MarginCacheArguments` object. """ # determine which order to generate margin pixels for partition_stats = args.catalog.partition_info.get_healpix_pixels() margin_pairs = _find_partition_margin_pixel_pairs( - partition_stats, - args.margin_order + partition_stats, args.margin_order ) # arcsec to degree conversion # TODO: remove this once hipscat uses arcsec for calculation - args.margin_threshold = args.margin_threshold / 3600. + args.margin_threshold = args.margin_threshold / 3600.0 - _create_margin_directory( - partition_stats, args.catalog_path - ) + _create_margin_directory(partition_stats, args.catalog_path) _map_to_margin_shards( client=client, diff --git a/src/hipscat_import/margin_cache/margin_cache_arguments.py b/src/hipscat_import/margin_cache/margin_cache_arguments.py index 31174d20..e6412390 100644 --- a/src/hipscat_import/margin_cache/margin_cache_arguments.py +++ b/src/hipscat_import/margin_cache/margin_cache_arguments.py @@ -4,6 +4,9 @@ import healpy as hp import numpy as np from hipscat.catalog import Catalog +from hipscat.catalog.margin_cache.margin_cache_catalog_info import ( + MarginCacheCatalogInfo, +) from hipscat.io import file_io from hipscat_import.runtime_arguments import RuntimeArguments @@ -53,9 +56,30 @@ def _check_arguments(self): margin_pixel_nside = hp.order2nside(self.margin_order) - if hp.nside2resol(margin_pixel_nside, arcmin=True) * 60. < self.margin_threshold: + if ( + hp.nside2resol(margin_pixel_nside, arcmin=True) * 60.0 + < self.margin_threshold + ): # pylint: disable=line-too-long warnings.warn( "Warning: margin pixels have a smaller resolution than margin_threshold; this may lead to data loss in the margin cache." ) # pylint: enable=line-too-long + + def to_catalog_info(self, total_rows) -> MarginCacheCatalogInfo: + """Catalog-type-specific dataset info.""" + info = { + "catalog_name": self.output_catalog_name, + "total_rows": total_rows, + "catalog_type": "margin", + "primary_catalog": self.input_catalog_path, + "margin_threshold": self.margin_threshold, + } + return MarginCacheCatalogInfo(**info) + + def additional_runtime_provenance_info(self) -> dict: + return { + "input_catalog_path": str(self.input_catalog_path), + "margin_threshold": self.margin_threshold, + "margin_order": self.margin_order, + } diff --git a/src/hipscat_import/margin_cache/margin_cache_map_reduce.py b/src/hipscat_import/margin_cache/margin_cache_map_reduce.py index 000395a9..63336d10 100644 --- a/src/hipscat_import/margin_cache/margin_cache_map_reduce.py +++ b/src/hipscat_import/margin_cache/margin_cache_map_reduce.py @@ -7,6 +7,7 @@ # pylint: disable=too-many-locals,too-many-arguments + def map_pixel_shards( partition_file, margin_pairs, @@ -14,7 +15,7 @@ def map_pixel_shards( output_path, margin_order, ra_column, - dec_column + dec_column, ): """Creates margin cache shards from a source partition file.""" data = pd.read_parquet(partition_file) @@ -39,7 +40,10 @@ def map_pixel_shards( dec_column=dec_column, ) -def _to_pixel_shard(data, margin_threshold, output_path, margin_order, ra_column, dec_column): + +def _to_pixel_shard( + data, margin_threshold, output_path, margin_order, ra_column, dec_column +): """Do boundary checking for the cached partition and then output remaining data.""" order, pix = data["partition_order"].iloc[0], data["partition_pixel"].iloc[0] source_order, source_pix = data["Norder"].iloc[0], data["Npix"].iloc[0] @@ -61,9 +65,7 @@ def _to_pixel_shard(data, margin_threshold, output_path, margin_order, ra_column ) else: data["margin_check"] = pixel_math.check_margin_bounds( - data[ra_column].values, - data[dec_column].values, - bounding_polygons + data[ra_column].values, data[dec_column].values, bounding_polygons ) # pylint: disable-next=singleton-comparison @@ -80,22 +82,23 @@ def _to_pixel_shard(data, margin_threshold, output_path, margin_order, ra_column file_io.make_directory(shard_dir, exist_ok=True) - shard_path = paths.pixel_catalog_file( - partition_dir, source_order, source_pix - ) + shard_path = paths.pixel_catalog_file(partition_dir, source_order, source_pix) - final_df = margin_data.drop(columns=[ - "partition_order", - "partition_pixel", - "margin_check", - "margin_pixel" - ]) + final_df = margin_data.drop( + columns=[ + "partition_order", + "partition_pixel", + "margin_check", + "margin_pixel", + ] + ) if is_polar: final_df = final_df.drop(columns=["is_trunc"]) final_df.to_parquet(shard_path) + def _margin_filter_polar( data, order, @@ -104,13 +107,11 @@ def _margin_filter_polar( margin_threshold, ra_column, dec_column, - bounding_polygons + bounding_polygons, ): """Filter out margin data around the poles.""" trunc_pix = pixel_math.get_truncated_margin_pixels( - order=order, - pix=pix, - margin_order=margin_order + order=order, pix=pix, margin_order=margin_order ) data["is_trunc"] = np.isin(data["margin_pixel"], trunc_pix) @@ -124,14 +125,12 @@ def _margin_filter_polar( order, pix, margin_order, - margin_threshold + margin_threshold, ) data.loc[data["is_trunc"] == True, "margin_check"] = trunc_check no_trunc_check = pixel_math.check_margin_bounds( - other_data[ra_column].values, - other_data[dec_column].values, - bounding_polygons + other_data[ra_column].values, other_data[dec_column].values, bounding_polygons ) data.loc[data["is_trunc"] == False, "margin_check"] = no_trunc_check # pylint: enable=singleton-comparison diff --git a/src/hipscat_import/pipeline.py b/src/hipscat_import/pipeline.py new file mode 100644 index 00000000..eb0f47c9 --- /dev/null +++ b/src/hipscat_import/pipeline.py @@ -0,0 +1,83 @@ +"""Flow control and pipeline entry points.""" +import smtplib +from email.message import EmailMessage + +from dask.distributed import Client + +import hipscat_import.association.run_association as association_runner +import hipscat_import.catalog.run_import as catalog_runner +import hipscat_import.index.run_index as index_runner +import hipscat_import.margin_cache.margin_cache as margin_runner +from hipscat_import.association.arguments import AssociationArguments +from hipscat_import.catalog.arguments import ImportArguments +from hipscat_import.index.arguments import IndexArguments +from hipscat_import.margin_cache.margin_cache_arguments import MarginCacheArguments +from hipscat_import.runtime_arguments import RuntimeArguments + +# pragma: no cover + + +def pipeline(args: RuntimeArguments): + """Pipeline that creates its own client from the provided runtime arguments""" + with Client( + local_directory=args.dask_tmp, + n_workers=args.dask_n_workers, + threads_per_worker=args.dask_threads_per_worker, + ) as client: + pipeline_with_client(args, client) + + +def pipeline_with_client(args: RuntimeArguments, client: Client): + """Pipeline that is run using an existing client. + + This can be useful in tests, or when a dask client requires some more complex + configuraion. + """ + try: + if not args: + raise ValueError("args is required and should be subclass of RuntimeArguments") + + if isinstance(args, ImportArguments): + catalog_runner.run(args, client) + elif isinstance(args, AssociationArguments): + association_runner.run(args) + elif isinstance(args, IndexArguments): + index_runner.run(args) + elif isinstance(args, MarginCacheArguments): + margin_runner.generate_margin_cache(args, client) + else: + raise ValueError("unknown args type") + except Exception as exception: # pylint: disable=broad-exception-caught + _send_failure_email(args, exception) + else: + _send_success_email(args) + + +def _send_failure_email(args: RuntimeArguments, exception: Exception): + if not args.completion_email_address: + raise exception + message = EmailMessage() + message["Subject"] = "hipscat-import failure." + message["To"] = args.completion_email_address + message.set_content(f"failed with message:\n{exception}") + + _send_email(message) + + +def _send_success_email(args): + if not args.completion_email_address: + return + message = EmailMessage() + message["Subject"] = "hipscat-import success." + message["To"] = args.completion_email_address + message.set_content(f"output_catalog_name: {args.output_catalog_name}") + + _send_email(message) + + +def _send_email(message: EmailMessage): + message["From"] = "updates@lsdb.io" + + # Send the message via our own SMTP server. + with smtplib.SMTP("localhost") as server: + server.send_message(message) diff --git a/src/hipscat_import/runtime_arguments.py b/src/hipscat_import/runtime_arguments.py index c7253450..a2bce52a 100644 --- a/src/hipscat_import/runtime_arguments.py +++ b/src/hipscat_import/runtime_arguments.py @@ -36,6 +36,10 @@ class RuntimeArguments: dask_threads_per_worker: int = 1 """number of threads per dask worker""" + completion_email_address: str = "" + """if provided, send an email to the indicated email address once the + import pipeline has complete.""" + catalog_path = "" """constructed output path for the catalog that will be something like /""" diff --git a/tests/hipscat_import/association/test_association_argument.py b/tests/hipscat_import/association/test_association_argument.py index b14c785a..d8e1be4b 100644 --- a/tests/hipscat_import/association/test_association_argument.py +++ b/tests/hipscat_import/association/test_association_argument.py @@ -1,4 +1,4 @@ -"""Tests of argument validation, in the absense of command line parsing""" +"""Tests of argument validation""" import pytest @@ -181,8 +181,8 @@ def test_all_required_args(tmp_path, small_sky_object_catalog): ) -def test_to_catalog_parameters(small_sky_object_catalog, tmp_path): - """Verify creation of catalog parameters for index to be created.""" +def test_to_catalog_info(small_sky_object_catalog, tmp_path): + """Verify creation of catalog parameters for association table to be created.""" args = AssociationArguments( primary_input_catalog_path=small_sky_object_catalog, primary_id_column="id", @@ -193,8 +193,9 @@ def test_to_catalog_parameters(small_sky_object_catalog, tmp_path): output_path=tmp_path, output_catalog_name="small_sky_self_join", ) - catalog_parameters = args.to_catalog_parameters() - assert catalog_parameters.catalog_name == args.output_catalog_name + catalog_info = args.to_catalog_info(total_rows=10) + assert catalog_info.catalog_name == args.output_catalog_name + assert catalog_info.total_rows == 10 def test_provenance_info(small_sky_object_catalog, tmp_path): diff --git a/tests/hipscat_import/association/test_association_map_reduce.py b/tests/hipscat_import/association/test_association_map_reduce.py index 88de77bb..8d310869 100644 --- a/tests/hipscat_import/association/test_association_map_reduce.py +++ b/tests/hipscat_import/association/test_association_map_reduce.py @@ -33,7 +33,7 @@ def test_map_association( ) subset_catalog_path = args.catalog_path - runner.run_with_client(args, dask_client) + runner.run(args, dask_client) with open( os.path.join(subset_catalog_path, "catalog_info.json"), "r", encoding="utf-8" diff --git a/tests/hipscat_import/association/test_run_association.py b/tests/hipscat_import/association/test_run_association.py index c1aa4b73..628b6c6b 100644 --- a/tests/hipscat_import/association/test_run_association.py +++ b/tests/hipscat_import/association/test_run_association.py @@ -6,6 +6,7 @@ import numpy.testing as npt import pandas as pd import pytest +from hipscat.catalog.association_catalog.association_catalog import AssociationCatalog import hipscat_import.association.run_association as runner from hipscat_import.association.arguments import AssociationArguments @@ -29,9 +30,8 @@ def test_object_to_source( small_sky_object_catalog, small_sky_source_catalog, tmp_path, - assert_text_file_matches, ): - """test stuff""" + """Test creating association between object and source catalogs.""" args = AssociationArguments( primary_input_catalog_path=small_sky_object_catalog, @@ -46,40 +46,11 @@ def test_object_to_source( ) runner.run(args) - # Check that the catalog metadata file exists - expected_metadata_lines = [ - "{", - ' "catalog_name": "small_sky_association",', - ' "catalog_type": "association",', - ' "epoch": "J2000",', - ' "ra_kw": "ra",', - ' "dec_kw": "dec",', - ' "total_rows": 17161', - "}", - ] - metadata_filename = os.path.join(args.catalog_path, "catalog_info.json") - assert_text_file_matches(expected_metadata_lines, metadata_filename) - - # Check that the partition *join* info file exists - expected_lines = [ - "Norder,Dir,Npix,join_Norder,join_Dir,join_Npix,num_rows", - "0,0,11,0,0,4,50", - "0,0,11,1,0,47,2395", - "0,0,11,2,0,176,385", - "0,0,11,2,0,177,1510", - "0,0,11,2,0,178,1634", - "0,0,11,2,0,179,1773", - "0,0,11,2,0,180,655", - "0,0,11,2,0,181,903", - "0,0,11,2,0,182,1246", - "0,0,11,2,0,183,1143", - "0,0,11,2,0,184,1390", - "0,0,11,2,0,185,2942", - "0,0,11,2,0,186,452", - "0,0,11,2,0,187,683", - ] - metadata_filename = os.path.join(args.catalog_path, "partition_join_info.csv") - assert_text_file_matches(expected_lines, metadata_filename) + ## Check that the association data can be parsed as a valid association catalog. + catalog = AssociationCatalog.read_from_hipscat(args.catalog_path) + assert catalog.on_disk + assert catalog.catalog_path == args.catalog_path + assert len(catalog.get_join_pixels()) == 14 ## Test one pixel that will have 50 rows in it. output_file = os.path.join( @@ -105,15 +76,19 @@ def test_object_to_source( ids = data_frame["join_id"] assert np.logical_and(ids >= 70_000, ids < 87161).all() + catalog = AssociationCatalog.read_from_hipscat(args.catalog_path) + assert catalog.on_disk + assert catalog.catalog_path == args.catalog_path + assert len(catalog.get_join_pixels()) == 14 + @pytest.mark.dask def test_source_to_object( small_sky_object_catalog, small_sky_source_catalog, tmp_path, - assert_text_file_matches, ): - """test stuff""" + """Test creating (weirder) association between source and object catalogs.""" args = AssociationArguments( primary_input_catalog_path=small_sky_source_catalog, @@ -128,40 +103,11 @@ def test_source_to_object( ) runner.run(args) - # Check that the catalog metadata file exists - expected_metadata_lines = [ - "{", - ' "catalog_name": "small_sky_association",', - ' "catalog_type": "association",', - ' "epoch": "J2000",', - ' "ra_kw": "ra",', - ' "dec_kw": "dec",', - ' "total_rows": 17161', - "}", - ] - metadata_filename = os.path.join(args.catalog_path, "catalog_info.json") - assert_text_file_matches(expected_metadata_lines, metadata_filename) - - # Check that the partition *join* info file exists - expected_lines = [ - "Norder,Dir,Npix,join_Norder,join_Dir,join_Npix,num_rows", - "0,0,4,0,0,11,50", - "1,0,47,0,0,11,2395", - "2,0,176,0,0,11,385", - "2,0,177,0,0,11,1510", - "2,0,178,0,0,11,1634", - "2,0,179,0,0,11,1773", - "2,0,180,0,0,11,655", - "2,0,181,0,0,11,903", - "2,0,182,0,0,11,1246", - "2,0,183,0,0,11,1143", - "2,0,184,0,0,11,1390", - "2,0,185,0,0,11,2942", - "2,0,186,0,0,11,452", - "2,0,187,0,0,11,683", - ] - metadata_filename = os.path.join(args.catalog_path, "partition_join_info.csv") - assert_text_file_matches(expected_lines, metadata_filename) + ## Check that the association data can be parsed as a valid association catalog. + catalog = AssociationCatalog.read_from_hipscat(args.catalog_path) + assert catalog.on_disk + assert catalog.catalog_path == args.catalog_path + assert len(catalog.get_join_pixels()) == 14 ## Test one pixel that will have 50 rows in it. output_file = os.path.join( @@ -192,9 +138,8 @@ def test_source_to_object( def test_self_join( small_sky_object_catalog, tmp_path, - assert_text_file_matches, ): - """test stuff""" + """Test creating association between object catalog and itself.""" args = AssociationArguments( primary_input_catalog_path=small_sky_object_catalog, @@ -209,27 +154,11 @@ def test_self_join( ) runner.run(args) - # Check that the catalog metadata file exists - expected_metadata_lines = [ - "{", - ' "catalog_name": "small_sky_self_association",', - ' "catalog_type": "association",', - ' "epoch": "J2000",', - ' "ra_kw": "ra",', - ' "dec_kw": "dec",', - ' "total_rows": 131', - "}", - ] - metadata_filename = os.path.join(args.catalog_path, "catalog_info.json") - assert_text_file_matches(expected_metadata_lines, metadata_filename) - - # Check that the partition *join* info file exists - expected_lines = [ - "Norder,Dir,Npix,join_Norder,join_Dir,join_Npix,num_rows", - "0,0,11,0,0,11,131", - ] - metadata_filename = os.path.join(args.catalog_path, "partition_join_info.csv") - assert_text_file_matches(expected_lines, metadata_filename) + ## Check that the association data can be parsed as a valid association catalog. + catalog = AssociationCatalog.read_from_hipscat(args.catalog_path) + assert catalog.on_disk + assert catalog.catalog_path == args.catalog_path + assert len(catalog.get_join_pixels()) == 1 ## Test one pixel that will have 50 rows in it. output_file = os.path.join( diff --git a/tests/hipscat_import/catalog/test_argument_validation.py b/tests/hipscat_import/catalog/test_argument_validation.py index 22a77bb1..9954415b 100644 --- a/tests/hipscat_import/catalog/test_argument_validation.py +++ b/tests/hipscat_import/catalog/test_argument_validation.py @@ -1,4 +1,4 @@ -"""Tests of argument validation, in the absense of command line parsing""" +"""Tests of argument validation""" import pytest @@ -168,7 +168,7 @@ def test_catalog_type(blank_data_dir, tmp_path): ) -def test_to_catalog_parameters(blank_data_dir, tmp_path): +def test_to_catalog_info(blank_data_dir, tmp_path): """Verify creation of catalog parameters for catalog to be created.""" args = ImportArguments( output_catalog_name="catalog", @@ -177,8 +177,9 @@ def test_to_catalog_parameters(blank_data_dir, tmp_path): output_path=tmp_path, tmp_dir=tmp_path, ) - catalog_parameters = args.to_catalog_parameters() - assert catalog_parameters.catalog_name == "catalog" + catalog_info = args.to_catalog_info(total_rows=10) + assert catalog_info.catalog_name == "catalog" + assert catalog_info.total_rows == 10 def test_provenance_info(blank_data_dir, tmp_path): diff --git a/tests/hipscat_import/catalog/test_arguments_commandline.py b/tests/hipscat_import/catalog/test_arguments_commandline.py deleted file mode 100644 index bff07f8c..00000000 --- a/tests/hipscat_import/catalog/test_arguments_commandline.py +++ /dev/null @@ -1,65 +0,0 @@ -"""Tests of command line argument validation""" - - -import pytest - -from hipscat_import.catalog.command_line_arguments import parse_command_line - -# pylint: disable=protected-access - - -def test_none(): - """No arguments provided. Should error for required args.""" - empty_args = [] - with pytest.raises(ValueError): - parse_command_line(empty_args) - - -def test_invalid_arguments(): - """Arguments are ill-formed.""" - bad_form_args = ["catalog", "path", "path"] - with pytest.raises(SystemExit): - parse_command_line(bad_form_args) - - -def test_invalid_path(): - """Required arguments are provided, but paths aren't found.""" - bad_path_args = ["-c", "catalog", "-i", "path", "-o", "path"] - with pytest.raises(FileNotFoundError): - parse_command_line(bad_path_args) - - -def test_good_paths(blank_data_dir, tmp_path): - """Required arguments are provided, and paths are found.""" - tmp_path_name = str(tmp_path) - good_args = [ - "--catalog_name", - "catalog", - "--input_path", - blank_data_dir, - "--output_path", - tmp_path_name, - "--input_format", - "csv", - ] - args = parse_command_line(good_args) - assert args.input_path == blank_data_dir - assert tmp_path_name in str(args.catalog_path) - - -def test_good_paths_short_names(blank_data_dir, tmp_path): - """Required arguments are provided, using short names for arguments.""" - tmp_path_name = str(tmp_path) - good_args = [ - "-c", - "catalog", - "-i", - blank_data_dir, - "-o", - tmp_path_name, - "-fmt", - "csv", - ] - args = parse_command_line(good_args) - assert args.input_path == blank_data_dir - assert tmp_path_name in str(args.catalog_path) diff --git a/tests/hipscat_import/catalog/test_file_readers.py b/tests/hipscat_import/catalog/test_file_readers.py index 87ef60b6..ff56229d 100644 --- a/tests/hipscat_import/catalog/test_file_readers.py +++ b/tests/hipscat_import/catalog/test_file_readers.py @@ -7,7 +7,7 @@ import pyarrow as pa import pyarrow.parquet as pq import pytest -from hipscat.catalog import CatalogParameters +from hipscat.catalog.catalog import CatalogInfo from hipscat_import.catalog.file_readers import ( CsvReader, @@ -17,6 +17,19 @@ ) +# pylint: disable=redefined-outer-name +@pytest.fixture +def basic_catalog_info(): + info = { + "catalog_name": "test_catalog", + "catalog_type": "object", + "total_rows": 100, + "ra_column": "ra", + "dec_column": "dec", + } + return CatalogInfo(**info) + + def test_unknown_file_type(): """File reader factory method should fail for unknown file types""" with pytest.raises(NotImplementedError): @@ -187,7 +200,7 @@ def test_csv_reader_pipe_delimited(formats_pipe_csv, tmp_path): assert np.all(column_types == expected_column_types) -def test_csv_reader_provenance_info(tmp_path): +def test_csv_reader_provenance_info(tmp_path, basic_catalog_info): """Test that we get some provenance info and it is parseable into JSON.""" reader = CsvReader( header=None, @@ -201,10 +214,9 @@ def test_csv_reader_provenance_info(tmp_path): }, ) provenance_info = reader.provenance_info() - base_catalog_parameters = CatalogParameters( - output_path=tmp_path, catalog_name="empty" - ) - io.write_provenance_info(base_catalog_parameters, provenance_info) + catalog_base_dir = os.path.join(tmp_path, "test_catalog") + os.makedirs(catalog_base_dir) + io.write_provenance_info(catalog_base_dir, basic_catalog_info, provenance_info) def test_parquet_reader(parquet_shards_shard_44_0): @@ -226,14 +238,13 @@ def test_parquet_reader_chunked(parquet_shards_shard_44_0): assert total_chunks == 7 -def test_parquet_reader_provenance_info(tmp_path): +def test_parquet_reader_provenance_info(tmp_path, basic_catalog_info): """Test that we get some provenance info and it is parseable into JSON.""" reader = ParquetReader(chunksize=1) provenance_info = reader.provenance_info() - base_catalog_parameters = CatalogParameters( - output_path=tmp_path, catalog_name="empty" - ) - io.write_provenance_info(base_catalog_parameters, provenance_info) + catalog_base_dir = os.path.join(tmp_path, "test_catalog") + os.makedirs(catalog_base_dir) + io.write_provenance_info(catalog_base_dir, basic_catalog_info, provenance_info) def test_read_fits(formats_fits): @@ -267,11 +278,10 @@ def test_read_fits_columns(formats_fits): assert list(frame.columns) == ["id", "ra", "dec"] -def test_fits_reader_provenance_info(tmp_path): +def test_fits_reader_provenance_info(tmp_path, basic_catalog_info): """Test that we get some provenance info and it is parseable into JSON.""" reader = FitsReader() provenance_info = reader.provenance_info() - base_catalog_parameters = CatalogParameters( - output_path=tmp_path, catalog_name="empty" - ) - io.write_provenance_info(base_catalog_parameters, provenance_info) + catalog_base_dir = os.path.join(tmp_path, "test_catalog") + os.makedirs(catalog_base_dir) + io.write_provenance_info(catalog_base_dir, basic_catalog_info, provenance_info) diff --git a/tests/hipscat_import/catalog/test_map_reduce.py b/tests/hipscat_import/catalog/test_map_reduce.py index 7589b2f0..356ac78c 100644 --- a/tests/hipscat_import/catalog/test_map_reduce.py +++ b/tests/hipscat_import/catalog/test_map_reduce.py @@ -194,9 +194,7 @@ def test_reduce_order0(parquet_shards_dir, assert_parquet_file_ids, tmp_path): assert_parquet_file_ids(output_file, "id", expected_ids) -def test_reduce_hipscat_index( - parquet_shards_dir, assert_parquet_file_ids, tmp_path -): +def test_reduce_hipscat_index(parquet_shards_dir, assert_parquet_file_ids, tmp_path): """Test reducing with or without a _hipscat_index field""" mr.reduce_pixel_shards( cache_path=parquet_shards_dir, diff --git a/tests/hipscat_import/catalog/test_run_import.py b/tests/hipscat_import/catalog/test_run_import.py index bd565996..50b6efff 100644 --- a/tests/hipscat_import/catalog/test_run_import.py +++ b/tests/hipscat_import/catalog/test_run_import.py @@ -4,6 +4,7 @@ import shutil import pytest +from hipscat.catalog.catalog import Catalog import hipscat_import.catalog.resume_files as rf import hipscat_import.catalog.run_import as runner @@ -12,15 +13,15 @@ def test_empty_args(): """Runner should fail with empty arguments""" - with pytest.raises(ValueError): - runner.run(None) + with pytest.raises(ValueError, match="args is required"): + runner.run(None, None) def test_bad_args(): """Runner should fail with mis-typed arguments""" args = {"output_catalog_name": "bad_arg_type"} - with pytest.raises(ValueError): - runner.run(args) + with pytest.raises(ValueError, match="ImportArguments"): + runner.run(args, None) @pytest.mark.dask @@ -29,7 +30,6 @@ def test_resume_dask_runner( small_sky_parts_dir, resume_dir, tmp_path, - assert_text_file_matches, assert_parquet_file_ids, ): """Test execution in the presence of some resume files.""" @@ -88,29 +88,16 @@ def test_resume_dask_runner( progress_bar=False, ) - runner.run_with_client(args, dask_client) + runner.run(args, dask_client) # Check that the catalog metadata file exists - expected_metadata_lines = [ - "{", - ' "catalog_name": "resume",', - ' "catalog_type": "object",', - ' "epoch": "J2000",', - ' "ra_kw": "ra",', - ' "dec_kw": "dec",', - ' "total_rows": 131', - "}", - ] - metadata_filename = os.path.join(args.catalog_path, "catalog_info.json") - assert_text_file_matches(expected_metadata_lines, metadata_filename) - - # Check that the partition info file exists - expected_partition_lines = [ - "Norder,Dir,Npix,num_rows", - "0,0,11,131", - ] - partition_filename = os.path.join(args.catalog_path, "partition_info.csv") - assert_text_file_matches(expected_partition_lines, partition_filename) + catalog = Catalog.read_from_hipscat(args.catalog_path) + assert catalog.on_disk + assert catalog.catalog_path == args.catalog_path + assert catalog.catalog_info.ra_column == "ra" + assert catalog.catalog_info.dec_column == "dec" + assert catalog.catalog_info.total_rows == 131 + assert len(catalog.get_pixels()) == 1 # Check that the catalog parquet file exists and contains correct object IDs output_file = os.path.join( @@ -144,10 +131,15 @@ def test_resume_dask_runner( progress_bar=False, ) - runner.run_with_client(args, dask_client) + runner.run(args, dask_client) - assert_text_file_matches(expected_metadata_lines, metadata_filename) - assert_text_file_matches(expected_partition_lines, partition_filename) + catalog = Catalog.read_from_hipscat(args.catalog_path) + assert catalog.on_disk + assert catalog.catalog_path == args.catalog_path + assert catalog.catalog_info.ra_column == "ra" + assert catalog.catalog_info.dec_column == "dec" + assert catalog.catalog_info.total_rows == 131 + assert len(catalog.get_pixels()) == 1 assert_parquet_file_ids(output_file, "id", expected_ids) @@ -156,7 +148,6 @@ def test_dask_runner( dask_client, small_sky_parts_dir, assert_parquet_file_ids, - assert_text_file_matches, tmp_path, ): """Test basic execution.""" @@ -170,29 +161,16 @@ def test_dask_runner( progress_bar=False, ) - runner.run_with_client(args, dask_client) + runner.run(args, dask_client) # Check that the catalog metadata file exists - expected_lines = [ - "{", - ' "catalog_name": "small_sky_object_catalog",', - ' "catalog_type": "object",', - ' "epoch": "J2000",', - ' "ra_kw": "ra",', - ' "dec_kw": "dec",', - ' "total_rows": 131', - "}", - ] - metadata_filename = os.path.join(args.catalog_path, "catalog_info.json") - assert_text_file_matches(expected_lines, metadata_filename) - - # Check that the partition info file exists - expected_lines = [ - "Norder,Dir,Npix,num_rows", - "0,0,11,131", - ] - metadata_filename = os.path.join(args.catalog_path, "partition_info.csv") - assert_text_file_matches(expected_lines, metadata_filename) + catalog = Catalog.read_from_hipscat(args.catalog_path) + assert catalog.on_disk + assert catalog.catalog_path == args.catalog_path + assert catalog.catalog_info.ra_column == "ra" + assert catalog.catalog_info.dec_column == "dec" + assert catalog.catalog_info.total_rows == 131 + assert len(catalog.get_pixels()) == 1 # Check that the catalog parquet file exists and contains correct object IDs output_file = os.path.join( @@ -217,7 +195,7 @@ def test_dask_runner_stats_only(dask_client, small_sky_parts_dir, tmp_path): debug_stats_only=True, ) - runner.run_with_client(args, dask_client) + runner.run(args, dask_client) metadata_filename = os.path.join(args.catalog_path, "catalog_info.json") assert os.path.exists(metadata_filename) diff --git a/tests/hipscat_import/catalog/test_run_round_trip.py b/tests/hipscat_import/catalog/test_run_round_trip.py index 174c09a9..49484164 100644 --- a/tests/hipscat_import/catalog/test_run_round_trip.py +++ b/tests/hipscat_import/catalog/test_run_round_trip.py @@ -11,17 +11,17 @@ import numpy.testing as npt import pandas as pd import pytest +from hipscat.catalog.catalog import Catalog import hipscat_import.catalog.run_import as runner from hipscat_import.catalog.arguments import ImportArguments -from hipscat_import.catalog.file_readers import get_file_reader +from hipscat_import.catalog.file_readers import CsvReader, get_file_reader @pytest.mark.dask def test_import_source_table( dask_client, small_sky_source_dir, - assert_text_file_matches, tmp_path, ): """Test basic execution, using a larger source file. @@ -44,42 +44,15 @@ def test_import_source_table( progress_bar=False, ) - runner.run_with_client(args, dask_client) + runner.run(args, dask_client) # Check that the catalog metadata file exists - expected_lines = [ - "{", - ' "catalog_name": "small_sky_source_catalog",', - ' "catalog_type": "source",', - ' "epoch": "J2000",', - ' "ra_kw": "source_ra",', - ' "dec_kw": "source_dec",', - ' "total_rows": 17161', - "}", - ] - metadata_filename = os.path.join(args.catalog_path, "catalog_info.json") - assert_text_file_matches(expected_lines, metadata_filename) - - # Check that the partition info file exists - expected_lines = [ - "Norder,Dir,Npix,num_rows", - "0,0,4,50", - "1,0,47,2395", - "2,0,176,385", - "2,0,177,1510", - "2,0,178,1634", - "2,0,179,1773", - "2,0,180,655", - "2,0,181,903", - "2,0,182,1246", - "2,0,183,1143", - "2,0,184,1390", - "2,0,185,2942", - "2,0,186,452", - "2,0,187,683", - ] - metadata_filename = os.path.join(args.catalog_path, "partition_info.csv") - assert_text_file_matches(expected_lines, metadata_filename) + catalog = Catalog.read_from_hipscat(args.catalog_path) + assert catalog.on_disk + assert catalog.catalog_path == args.catalog_path + assert catalog.catalog_info.ra_column == "source_ra" + assert catalog.catalog_info.dec_column == "source_dec" + assert len(catalog.get_pixels()) == 14 @pytest.mark.dask @@ -111,7 +84,7 @@ def test_import_mixed_schema_csv( use_schema_file=mixed_schema_csv_parquet, ) - runner.run_with_client(args, dask_client) + runner.run(args, dask_client) # Check that the catalog parquet file exists output_file = os.path.join( @@ -131,8 +104,8 @@ def test_import_preserve_index( ): """Test basic execution, with input with pandas metadata. - the input file is a parquet file with some pandas metadata. - this verifies that the parquet file at the end also has pandas - metadata, and the user's preferred id is retained as the index, + this verifies that the parquet file at the end also has pandas + metadata, and the user's preferred id is retained as the index, when requested. """ @@ -167,7 +140,7 @@ def test_import_preserve_index( progress_bar=False, ) - runner.run_with_client(args, dask_client) + runner.run(args, dask_client) # Check that the catalog parquet file exists output_file = os.path.join( @@ -195,7 +168,7 @@ def test_import_preserve_index( progress_bar=False, ) - runner.run_with_client(args, dask_client) + runner.run(args, dask_client) # Check that the catalog parquet file exists output_file = os.path.join( @@ -222,8 +195,8 @@ def test_import_multiindex( """Test basic execution, with input with pandas metadata - this is *similar* to the above test - the input file is a parquet file with a multi-level pandas index. - this verifies that the parquet file at the end also has pandas - metadata, and the user's preferred id is retained as the index, + this verifies that the parquet file at the end also has pandas + metadata, and the user's preferred id is retained as the index, when requested. """ @@ -262,7 +235,7 @@ def test_import_multiindex( progress_bar=False, ) - runner.run_with_client(args, dask_client) + runner.run(args, dask_client) # Check that the catalog parquet file exists output_file = os.path.join( @@ -290,7 +263,7 @@ def test_import_multiindex( progress_bar=False, ) - runner.run_with_client(args, dask_client) + runner.run(args, dask_client) # Check that the catalog parquet file exists output_file = os.path.join( @@ -310,7 +283,6 @@ def test_import_multiindex( def test_import_constant_healpix_order( dask_client, small_sky_parts_dir, - assert_text_file_matches, tmp_path, ): """Test basic execution. @@ -327,42 +299,16 @@ def test_import_constant_healpix_order( progress_bar=False, ) - runner.run_with_client(args, dask_client) - - # Check that the partition info file exists - all pixels at order 2! - expected_lines = [ - "Norder,Dir,Npix,num_rows", - "2,0,176,4", - "2,0,177,11", - "2,0,178,14", - "2,0,179,13", - "2,0,180,5", - "2,0,181,7", - "2,0,182,8", - "2,0,183,9", - "2,0,184,11", - "2,0,185,23", - "2,0,186,4", - "2,0,187,4", - "2,0,188,17", - "2,0,190,1", - ] - metadata_filename = os.path.join(args.catalog_path, "partition_info.csv") - assert_text_file_matches(expected_lines, metadata_filename) + runner.run(args, dask_client) # Check that the catalog metadata file exists - expected_lines = [ - "{", - ' "catalog_name": "small_sky_object_catalog",', - ' "catalog_type": "object",', - ' "epoch": "J2000",', - ' "ra_kw": "ra",', - ' "dec_kw": "dec",', - ' "total_rows": 131', - "}", - ] - metadata_filename = os.path.join(args.catalog_path, "catalog_info.json") - assert_text_file_matches(expected_lines, metadata_filename) + catalog = Catalog.read_from_hipscat(args.catalog_path) + assert catalog.on_disk + assert catalog.catalog_path == args.catalog_path + # Check that the partition info file exists - all pixels at order 2! + assert all( + pixel.order == 2 for pixel in catalog.partition_info.get_healpix_pixels() + ) # Pick a parquet file and make sure it contains as many rows as we expect output_file = os.path.join( @@ -373,3 +319,47 @@ def test_import_constant_healpix_order( assert len(data_frame) == 14 ids = data_frame["id"] assert np.logical_and(ids >= 700, ids < 832).all() + +@pytest.mark.dask +def test_import_starr_file( + dask_client, + formats_dir, + assert_parquet_file_ids, + tmp_path, +): + """Test basic execution. + - tests that we can run pipeline with a totally unknown file type, so long + as a valid InputReader implementation is provided. + """ + + class StarrReader(CsvReader): + """Shallow subclass""" + + args = ImportArguments( + output_catalog_name="starr", + input_path=formats_dir, + input_format="starr", + file_reader=StarrReader(), + output_path=tmp_path, + dask_tmp=tmp_path, + highest_healpix_order=2, + pixel_threshold=3_000, + progress_bar=False, + ) + + runner.run(args, dask_client) + + # Check that the catalog metadata file exists + catalog = Catalog.read_from_hipscat(args.catalog_path) + assert catalog.on_disk + assert catalog.catalog_path == args.catalog_path + assert catalog.catalog_info.total_rows == 131 + assert len(catalog.get_pixels()) == 1 + + # Check that the catalog parquet file exists and contains correct object IDs + output_file = os.path.join( + args.catalog_path, "Norder=0", "Dir=0", "Npix=11.parquet" + ) + + expected_ids = [*range(700, 831)] + assert_parquet_file_ids(output_file, "id", expected_ids) \ No newline at end of file diff --git a/tests/hipscat_import/conftest.py b/tests/hipscat_import/conftest.py index 9e091137..fe2d4e9c 100644 --- a/tests/hipscat_import/conftest.py +++ b/tests/hipscat_import/conftest.py @@ -93,6 +93,11 @@ def empty_data_dir(test_data_dir): return os.path.join(test_data_dir, "empty") +@pytest.fixture +def formats_dir(test_data_dir): + return os.path.join(test_data_dir, "test_formats") + + @pytest.fixture def formats_headers_csv(test_data_dir): return os.path.join(test_data_dir, "test_formats", "headers.csv") diff --git a/tests/hipscat_import/data/small_sky_object_catalog/catalog_info.json b/tests/hipscat_import/data/small_sky_object_catalog/catalog_info.json index 451af75e..59eaef9c 100644 --- a/tests/hipscat_import/data/small_sky_object_catalog/catalog_info.json +++ b/tests/hipscat_import/data/small_sky_object_catalog/catalog_info.json @@ -1,8 +1,8 @@ { "catalog_name": "small_sky_object_catalog", "catalog_type": "object", + "total_rows": 131, "epoch": "J2000", - "ra_kw": "ra", - "dec_kw": "dec", - "total_rows": 131 -} + "ra_column": "ra", + "dec_column": "dec" +} \ No newline at end of file diff --git a/tests/hipscat_import/data/small_sky_source_catalog/catalog_info.json b/tests/hipscat_import/data/small_sky_source_catalog/catalog_info.json index 7c819f1d..0491b5c5 100644 --- a/tests/hipscat_import/data/small_sky_source_catalog/catalog_info.json +++ b/tests/hipscat_import/data/small_sky_source_catalog/catalog_info.json @@ -1,8 +1,8 @@ { "catalog_name": "small_sky_source_catalog", "catalog_type": "source", + "total_rows": 17161, "epoch": "J2000", "ra_column": "source_ra", - "dec_column": "source_dec", - "total_rows": 17161 -} + "dec_column": "source_dec" +} \ No newline at end of file diff --git a/tests/hipscat_import/data/test_formats/catalog.starr b/tests/hipscat_import/data/test_formats/catalog.starr new file mode 100644 index 00000000..1bb66e95 --- /dev/null +++ b/tests/hipscat_import/data/test_formats/catalog.starr @@ -0,0 +1,132 @@ +id,ra,dec,ra_error,dec_error +700,282.5,-58.5,0,0 +701,299.5,-48.5,0,0 +702,310.5,-27.5,0,0 +703,286.5,-69.5,0,0 +704,326.5,-45.5,0,0 +705,335.5,-32.5,0,0 +706,297.5,-36.5,0,0 +707,308.5,-69.5,0,0 +708,307.5,-37.5,0,0 +709,294.5,-45.5,0,0 +710,341.5,-39.5,0,0 +711,305.5,-49.5,0,0 +712,288.5,-49.5,0,0 +713,298.5,-41.5,0,0 +714,303.5,-37.5,0,0 +715,280.5,-35.5,0,0 +716,305.5,-60.5,0,0 +717,303.5,-43.5,0,0 +718,292.5,-60.5,0,0 +719,344.5,-39.5,0,0 +720,344.5,-47.5,0,0 +721,314.5,-34.5,0,0 +722,350.5,-58.5,0,0 +723,315.5,-68.5,0,0 +724,323.5,-41.5,0,0 +725,308.5,-41.5,0,0 +726,341.5,-37.5,0,0 +727,301.5,-44.5,0,0 +728,328.5,-47.5,0,0 +729,299.5,-59.5,0,0 +730,322.5,-61.5,0,0 +731,343.5,-52.5,0,0 +732,337.5,-39.5,0,0 +733,329.5,-65.5,0,0 +734,348.5,-66.5,0,0 +735,299.5,-65.5,0,0 +736,303.5,-52.5,0,0 +737,316.5,-33.5,0,0 +738,345.5,-64.5,0,0 +739,332.5,-57.5,0,0 +740,306.5,-33.5,0,0 +741,303.5,-38.5,0,0 +742,348.5,-45.5,0,0 +743,307.5,-25.5,0,0 +744,349.5,-39.5,0,0 +745,337.5,-38.5,0,0 +746,283.5,-31.5,0,0 +747,327.5,-61.5,0,0 +748,296.5,-63.5,0,0 +749,293.5,-55.5,0,0 +750,338.5,-67.5,0,0 +751,330.5,-44.5,0,0 +752,291.5,-34.5,0,0 +753,307.5,-45.5,0,0 +754,313.5,-30.5,0,0 +755,303.5,-38.5,0,0 +756,319.5,-35.5,0,0 +757,346.5,-34.5,0,0 +758,325.5,-53.5,0,0 +759,290.5,-48.5,0,0 +760,320.5,-53.5,0,0 +761,329.5,-29.5,0,0 +762,327.5,-51.5,0,0 +763,306.5,-38.5,0,0 +764,297.5,-45.5,0,0 +765,306.5,-35.5,0,0 +766,310.5,-63.5,0,0 +767,314.5,-29.5,0,0 +768,297.5,-60.5,0,0 +769,307.5,-42.5,0,0 +770,285.5,-29.5,0,0 +771,348.5,-67.5,0,0 +772,348.5,-64.5,0,0 +773,293.5,-50.5,0,0 +774,281.5,-54.5,0,0 +775,321.5,-54.5,0,0 +776,344.5,-63.5,0,0 +777,307.5,-39.5,0,0 +778,313.5,-36.5,0,0 +779,347.5,-29.5,0,0 +780,326.5,-52.5,0,0 +781,330.5,-46.5,0,0 +782,290.5,-39.5,0,0 +783,286.5,-42.5,0,0 +784,338.5,-40.5,0,0 +785,296.5,-44.5,0,0 +786,336.5,-33.5,0,0 +787,320.5,-47.5,0,0 +788,283.5,-61.5,0,0 +789,287.5,-45.5,0,0 +790,286.5,-35.5,0,0 +791,312.5,-28.5,0,0 +792,320.5,-69.5,0,0 +793,289.5,-58.5,0,0 +794,300.5,-66.5,0,0 +795,306.5,-58.5,0,0 +796,320.5,-33.5,0,0 +797,308.5,-62.5,0,0 +798,316.5,-36.5,0,0 +799,313.5,-31.5,0,0 +800,299.5,-37.5,0,0 +801,309.5,-50.5,0,0 +802,304.5,-49.5,0,0 +803,336.5,-25.5,0,0 +804,322.5,-66.5,0,0 +805,297.5,-52.5,0,0 +806,312.5,-29.5,0,0 +807,303.5,-60.5,0,0 +808,320.5,-40.5,0,0 +809,283.5,-34.5,0,0 +810,301.5,-59.5,0,0 +811,315.5,-68.5,0,0 +812,346.5,-60.5,0,0 +813,349.5,-37.5,0,0 +814,312.5,-33.5,0,0 +815,283.5,-68.5,0,0 +816,288.5,-69.5,0,0 +817,318.5,-48.5,0,0 +818,300.5,-55.5,0,0 +819,313.5,-35.5,0,0 +820,286.5,-46.5,0,0 +821,330.5,-52.5,0,0 +822,301.5,-54.5,0,0 +823,338.5,-45.5,0,0 +824,305.5,-28.5,0,0 +825,315.5,-30.5,0,0 +826,335.5,-69.5,0,0 +827,310.5,-40.5,0,0 +828,330.5,-26.5,0,0 +829,314.5,-35.5,0,0 +830,306.5,-50.5,0,0 \ No newline at end of file diff --git a/tests/hipscat_import/index/test_index_argument.py b/tests/hipscat_import/index/test_index_argument.py index c7032806..cf7036a4 100644 --- a/tests/hipscat_import/index/test_index_argument.py +++ b/tests/hipscat_import/index/test_index_argument.py @@ -1,4 +1,4 @@ -"""Tests of argument validation, in the absense of command line parsing""" +"""Tests of argument validation""" import pytest @@ -128,7 +128,7 @@ def test_compute_partition_size(tmp_path, small_sky_object_catalog): ) -def test_to_catalog_parameters(small_sky_object_catalog, tmp_path): +def test_to_catalog_info(small_sky_object_catalog, tmp_path): """Verify creation of catalog parameters for index to be created.""" args = IndexArguments( input_catalog_path=small_sky_object_catalog, @@ -138,8 +138,9 @@ def test_to_catalog_parameters(small_sky_object_catalog, tmp_path): include_hipscat_index=True, include_order_pixel=True, ) - catalog_parameters = args.to_catalog_parameters() - assert catalog_parameters.catalog_name == args.output_catalog_name + catalog_info = args.to_catalog_info(total_rows=10) + assert catalog_info.catalog_name == args.output_catalog_name + assert catalog_info.total_rows == 10 def test_provenance_info(small_sky_object_catalog, tmp_path): diff --git a/tests/hipscat_import/index/test_run_index.py b/tests/hipscat_import/index/test_run_index.py index 0f84cd61..f9bf3d3e 100644 --- a/tests/hipscat_import/index/test_run_index.py +++ b/tests/hipscat_import/index/test_run_index.py @@ -5,6 +5,7 @@ import pyarrow as pa import pyarrow.parquet as pq import pytest +from hipscat.catalog.dataset.dataset import Dataset import hipscat_import.index.run_index as runner from hipscat_import.index.arguments import IndexArguments @@ -27,7 +28,6 @@ def test_bad_args(): def test_run_index( small_sky_object_catalog, tmp_path, - assert_text_file_matches, ): """Test appropriate metadata is written""" @@ -42,18 +42,9 @@ def test_run_index( runner.run(args) # Check that the catalog metadata file exists - expected_metadata_lines = [ - "{", - ' "catalog_name": "small_sky_object_index",', - ' "catalog_type": "index",', - ' "epoch": "J2000",', - ' "ra_kw": "ra",', - ' "dec_kw": "dec",', - ' "total_rows": 131', - "}", - ] - metadata_filename = os.path.join(args.catalog_path, "catalog_info.json") - assert_text_file_matches(expected_metadata_lines, metadata_filename) + catalog = Dataset.read_from_hipscat(args.catalog_path) + assert catalog.on_disk + assert catalog.catalog_path == args.catalog_path basic_index_parquet_schema = pa.schema( [ diff --git a/tests/hipscat_import/margin_cache/test_arguments_margin_cache.py b/tests/hipscat_import/margin_cache/test_arguments_margin_cache.py index 0d1537af..00bab144 100644 --- a/tests/hipscat_import/margin_cache/test_arguments_margin_cache.py +++ b/tests/hipscat_import/margin_cache/test_arguments_margin_cache.py @@ -7,6 +7,7 @@ # pylint: disable=protected-access + def test_empty_required(tmp_path): """*Most* required arguments are provided.""" ## Input catalog path is missing @@ -17,6 +18,7 @@ def test_empty_required(tmp_path): output_catalog_name="catalog_cache", ) + def test_margin_order_dynamic(small_sky_source_catalog, tmp_path): """Ensure we can dynamically set the margin_order""" args = MarginCacheArguments( @@ -28,6 +30,7 @@ def test_margin_order_dynamic(small_sky_source_catalog, tmp_path): assert args.margin_order == 3 + def test_margin_order_static(small_sky_source_catalog, tmp_path): """Ensure we can manually set the margin_order""" args = MarginCacheArguments( @@ -35,11 +38,12 @@ def test_margin_order_static(small_sky_source_catalog, tmp_path): input_catalog_path=small_sky_source_catalog, output_path=tmp_path, output_catalog_name="catalog_cache", - margin_order=4 + margin_order=4, ) assert args.margin_order == 4 + def test_margin_order_invalid(small_sky_source_catalog, tmp_path): """Ensure we raise an exception when margin_order is invalid""" with pytest.raises(ValueError, match="margin_order"): @@ -48,9 +52,10 @@ def test_margin_order_invalid(small_sky_source_catalog, tmp_path): input_catalog_path=small_sky_source_catalog, output_path=tmp_path, output_catalog_name="catalog_cache", - margin_order=2 + margin_order=2, ) + def test_margin_threshold_warns(small_sky_source_catalog, tmp_path): """Ensure we give a warning when margin_threshold is greater than margin_order resolution""" @@ -60,5 +65,33 @@ def test_margin_threshold_warns(small_sky_source_catalog, tmp_path): input_catalog_path=small_sky_source_catalog, output_path=tmp_path, output_catalog_name="catalog_cache", - margin_order=16 + margin_order=16, ) + + +def test_to_catalog_info(small_sky_source_catalog, tmp_path): + """Verify creation of catalog info for margin cache to be created.""" + args = MarginCacheArguments( + margin_threshold=5.0, + input_catalog_path=small_sky_source_catalog, + output_path=tmp_path, + output_catalog_name="catalog_cache", + margin_order=4, + ) + catalog_info = args.to_catalog_info(total_rows=10) + assert catalog_info.catalog_name == args.output_catalog_name + assert catalog_info.total_rows == 10 + + +def test_provenance_info(small_sky_source_catalog, tmp_path): + """Verify that provenance info includes margin-cache-specific fields.""" + args = MarginCacheArguments( + margin_threshold=5.0, + input_catalog_path=small_sky_source_catalog, + output_path=tmp_path, + output_catalog_name="catalog_cache", + margin_order=4, + ) + + runtime_args = args.provenance_info()["runtime_args"] + assert "margin_threshold" in runtime_args diff --git a/tests/hipscat_import/margin_cache/test_margin_cache.py b/tests/hipscat_import/margin_cache/test_margin_cache.py index cfdafb05..56f13389 100644 --- a/tests/hipscat_import/margin_cache/test_margin_cache.py +++ b/tests/hipscat_import/margin_cache/test_margin_cache.py @@ -8,6 +8,8 @@ import hipscat_import.margin_cache.margin_cache as mc from hipscat_import.margin_cache import MarginCacheArguments +# pylint: disable=protected-access + @pytest.mark.dask(timeout=20) def test_margin_cache_gen(small_sky_source_catalog, tmp_path, dask_client): @@ -21,7 +23,7 @@ def test_margin_cache_gen(small_sky_source_catalog, tmp_path, dask_client): assert args.catalog.catalog_info.ra_column == "source_ra" - mc.generate_margin_cache_with_client(dask_client, args) + mc.generate_margin_cache(args, dask_client) print(args.catalog.partition_info.get_healpix_pixels()) @@ -45,8 +47,7 @@ def test_partition_margin_pixel_pairs(small_sky_source_catalog, tmp_path): ) margin_pairs = mc._find_partition_margin_pixel_pairs( - args.catalog.partition_info.get_healpix_pixels(), - args.margin_order + args.catalog.partition_info.get_healpix_pixels(), args.margin_order ) expected = np.array([725, 733, 757, 765, 727, 735, 759, 767, 469, 192]) @@ -54,6 +55,7 @@ def test_partition_margin_pixel_pairs(small_sky_source_catalog, tmp_path): npt.assert_array_equal(margin_pairs.iloc[:10]["margin_pixel"], expected) assert len(margin_pairs) == 196 + def test_create_margin_directory(small_sky_source_catalog, tmp_path): args = MarginCacheArguments( margin_threshold=5.0, @@ -64,10 +66,8 @@ def test_create_margin_directory(small_sky_source_catalog, tmp_path): mc._create_margin_directory( stats=args.catalog.partition_info.get_healpix_pixels(), - output_path=args.catalog_path + output_path=args.catalog_path, ) - output = file_io.append_paths_to_pointer( - args.catalog_path, "Norder=0", "Dir=0" - ) + output = file_io.append_paths_to_pointer(args.catalog_path, "Norder=0", "Dir=0") assert file_io.does_file_or_directory_exist(output) diff --git a/tests/hipscat_import/margin_cache/test_margin_cache_map_reduce.py b/tests/hipscat_import/margin_cache/test_margin_cache_map_reduce.py index 9db54bf5..e5ae1d47 100644 --- a/tests/hipscat_import/margin_cache/test_margin_cache_map_reduce.py +++ b/tests/hipscat_import/margin_cache/test_margin_cache_map_reduce.py @@ -4,19 +4,17 @@ from hipscat_import.margin_cache import margin_cache_map_reduce -keep_cols = [ - "weird_ra", - "weird_dec" -] +keep_cols = ["weird_ra", "weird_dec"] drop_cols = [ - "partition_order", - "partition_pixel", - "margin_check", + "partition_order", + "partition_pixel", + "margin_check", "margin_pixel", - "is_trunc" + "is_trunc", ] + def validate_result_dataframe(df_path, expected_len): res_df = pd.read_parquet(df_path) @@ -38,12 +36,11 @@ def test_to_pixel_shard_equator(tmp_path, basic_data_shard_df): output_path=tmp_path, margin_order=3, ra_column="weird_ra", - dec_column="weird_dec" + dec_column="weird_dec", ) path = file_io.append_paths_to_pointer( - tmp_path, - "Norder=1/Dir=0/Npix=21/Norder=1/Dir=0/Npix=0.parquet" + tmp_path, "Norder=1/Dir=0/Npix=21/Norder=1/Dir=0/Npix=0.parquet" ) assert file_io.does_file_or_directory_exist(path) @@ -58,12 +55,11 @@ def test_to_pixel_shard_polar(tmp_path, polar_data_shard_df): output_path=tmp_path, margin_order=3, ra_column="weird_ra", - dec_column="weird_dec" + dec_column="weird_dec", ) path = file_io.append_paths_to_pointer( - tmp_path, - "Norder=2/Dir=0/Npix=15/Norder=2/Dir=0/Npix=0.parquet" + tmp_path, "Norder=2/Dir=0/Npix=15/Norder=2/Dir=0/Npix=0.parquet" ) assert file_io.does_file_or_directory_exist(path) diff --git a/tests/hipscat_import/test_runtime_arguments.py b/tests/hipscat_import/test_runtime_arguments.py index 97dc4ba5..7fc3c087 100644 --- a/tests/hipscat_import/test_runtime_arguments.py +++ b/tests/hipscat_import/test_runtime_arguments.py @@ -1,4 +1,4 @@ -"""Tests of argument validation, in the absense of command line parsing""" +"""Tests of argument validation""" import os