diff --git a/.github/workflows/testing-and-coverage.yml b/.github/workflows/testing-and-coverage.yml
index 4507157b..b497a967 100644
--- a/.github/workflows/testing-and-coverage.yml
+++ b/.github/workflows/testing-and-coverage.yml
@@ -27,8 +27,7 @@ jobs:
       run: |
         sudo apt-get update
         python -m pip install --upgrade pip
-        pip install .
-        pip install .[dev]
+        pip install -e .[dev]
         if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
     - name: Run unit tests with pytest
       run: |
diff --git a/docs/notebooks/intro_notebook.ipynb b/docs/notebooks/intro_notebook.ipynb
index 2e7779f5..71086d5e 100644
--- a/docs/notebooks/intro_notebook.ipynb
+++ b/docs/notebooks/intro_notebook.ipynb
@@ -65,9 +65,9 @@
     "    # Calculate the cosine of each value of X\n",
     "    z = np.cos(x)\n",
     "    # Plot the sine wave in blue, using degrees rather than radians on the X axis\n",
-    "    pl.plot(xdeg, y, color='blue', label='Sine wave')\n",
+    "    pl.plot(xdeg, y, color=\"blue\", label=\"Sine wave\")\n",
     "    # Plot the cos wave in green, using degrees rather than radians on the X axis\n",
-    "    pl.plot(xdeg, z, color='green', label='Cosine wave')\n",
+    "    pl.plot(xdeg, z, color=\"green\", label=\"Cosine wave\")\n",
     "    pl.xlabel(\"Degrees\")\n",
     "    # More sensible X axis values\n",
     "    pl.xticks(np.arange(0, 361, 45))\n",
diff --git a/pyproject.toml b/pyproject.toml
index 9fbb62d9..64e6c3ac 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -62,12 +62,15 @@ write_to = "src/hipscat_import/_version.py"
 [tool.setuptools.package-data]
 hipscat_import = ["py.typed"]
 
-[project.scripts]
-hipscat-import = "hipscat_import.control:main"
-hc = "hipscat_import.control:main"
 
 [tool.pytest.ini_options]
 timeout = 1
 markers = [
     "dask: mark tests as having a dask client runtime dependency",
+]
+
+[tool.coverage.report]
+omit = [
+    "src/hipscat_import/_version.py", # auto-generated
+    "src/hipscat_import/pipeline.py", # too annoying to test
 ]
\ No newline at end of file
diff --git a/src/hipscat_import/association/run_association.py b/src/hipscat_import/association/run_association.py
index 461a36f1..162fd013 100644
--- a/src/hipscat_import/association/run_association.py
+++ b/src/hipscat_import/association/run_association.py
@@ -9,21 +9,16 @@ from tqdm import tqdm
 
 from hipscat_import.association.arguments import AssociationArguments
-from hipscat_import.association.map_reduce import (map_association,
-                                                   reduce_association)
+from hipscat_import.association.map_reduce import map_association, reduce_association
 
 
-def _validate_args(args):
+def run(args):
+    """Run the association pipeline"""
    if not args:
        raise TypeError("args is required and should be type AssociationArguments")
    if not isinstance(args, AssociationArguments):
        raise TypeError("args must be type AssociationArguments")
 
-
-def run(args):
-    """Run the association pipeline"""
-    _validate_args(args)
-
     with tqdm(total=1, desc="Mapping ", disable=not args.progress_bar) as step_progress:
         map_association(args)
         step_progress.update(1)
diff --git a/src/hipscat_import/catalog/__init__.py b/src/hipscat_import/catalog/__init__.py
index a2b40f62..fe9a1e61 100644
--- a/src/hipscat_import/catalog/__init__.py
+++ b/src/hipscat_import/catalog/__init__.py
@@ -1,13 +1,26 @@
 """All modules for importing new catalogs."""
 
 from .arguments import ImportArguments
-from .file_readers import (CsvReader, FitsReader, InputReader, ParquetReader,
-                           get_file_reader)
+from .file_readers import (
+    CsvReader,
+    FitsReader,
+    InputReader,
+    ParquetReader,
+    get_file_reader,
+)
 from .map_reduce import map_to_pixels, reduce_pixel_shards, split_pixels
-from .resume_files import (clean_resume_files, is_mapping_done,
-                           is_reducing_done, read_histogram, read_mapping_keys,
-                           read_reducing_keys, set_mapping_done,
-                           set_reducing_done, write_histogram,
-                           write_mapping_done_key, write_mapping_start_key,
-                           write_reducing_key)
-from .run_import import run, run_with_client
+from .resume_files import (
+    clean_resume_files,
+    is_mapping_done,
+    is_reducing_done,
+    read_histogram,
+    read_mapping_keys,
+    read_reducing_keys,
+    set_mapping_done,
+    set_reducing_done,
+    write_histogram,
+    write_mapping_done_key,
+    write_mapping_start_key,
+    write_reducing_key,
+)
+from .run_import import run
diff --git a/src/hipscat_import/catalog/map_reduce.py b/src/hipscat_import/catalog/map_reduce.py
index cc75dc70..34fa30a5 100644
--- a/src/hipscat_import/catalog/map_reduce.py
+++ b/src/hipscat_import/catalog/map_reduce.py
@@ -98,7 +98,7 @@ def map_to_pixels(
 
     Args:
         input_file (FilePointer): file to read for catalog data.
-        file_reader (hipscat_import.catalog.file_readers.InputReader): instance of input 
+        file_reader (hipscat_import.catalog.file_readers.InputReader): instance of input
             reader that specifies arguments necessary for reading from the input file.
         shard_suffix (str): unique counter for this input file, used
             when creating intermediate files
@@ -137,7 +137,7 @@ def split_pixels(
 
     Args:
         input_file (FilePointer): file to read for catalog data.
-        file_reader (hipscat_import.catalog.file_readers.InputReader): instance 
+        file_reader (hipscat_import.catalog.file_readers.InputReader): instance
             of input reader that specifies arguments necessary for reading from the input file.
         shard_suffix (str): unique counter for this input file, used
             when creating intermediate files
diff --git a/src/hipscat_import/catalog/run_import.py b/src/hipscat_import/catalog/run_import.py
index 273e1da7..8c5db1ad 100644
--- a/src/hipscat_import/catalog/run_import.py
+++ b/src/hipscat_import/catalog/run_import.py
@@ -7,7 +7,7 @@
 
 import hipscat.io.write_metadata as io
 import numpy as np
-from dask.distributed import Client, as_completed
+from dask.distributed import as_completed
 from hipscat import pixel_math
 from tqdm import tqdm
 
@@ -151,30 +151,12 @@ def _reduce_pixels(args, destination_pixel_map, client):
     resume.set_reducing_done(args.tmp_path)
 
 
-def _validate_args(args):
+def run(args, client):
+    """Run catalog creation pipeline."""
     if not args:
         raise ValueError("args is required and should be type ImportArguments")
     if not isinstance(args, ImportArguments):
         raise ValueError("args must be type ImportArguments")
-
-
-def run(args):
-    """Importer that creates a dask client from the arguments"""
-    _validate_args(args)
-
-    # pylint: disable=duplicate-code
-    with Client(
-        local_directory=args.dask_tmp,
-        n_workers=args.dask_n_workers,
-        threads_per_worker=args.dask_threads_per_worker,
-    ) as client:  # pragma: no cover
-        run_with_client(args, client)
-    # pylint: enable=duplicate-code
-
-
-def run_with_client(args, client):
-    """Importer, where the client context may out-live the runner"""
-    _validate_args(args)
 
     raw_histogram = _map_pixels(args, client)
 
     with tqdm(
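Reviewer note (not part of the patch): with this change the catalog runner no longer creates its own Dask client; callers construct one and pass it in, exactly as the updated tests below do with `runner.run(args, dask_client)`. A minimal sketch of the new calling convention, assuming an already-constructed `ImportArguments` instance; the worker counts are illustrative only:

from dask.distributed import Client

import hipscat_import.catalog.run_import as runner
from hipscat_import.catalog.arguments import ImportArguments


def import_catalog(args: ImportArguments) -> None:
    """Run the catalog import with a caller-managed Dask client."""
    # The runner validates `args` and performs map/split/reduce on this client.
    with Client(n_workers=4, threads_per_worker=1) as client:
        runner.run(args, client)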
diff --git a/src/hipscat_import/index/run_index.py b/src/hipscat_import/index/run_index.py
index 46a7bd3e..e1912538 100644
--- a/src/hipscat_import/index/run_index.py
+++ b/src/hipscat_import/index/run_index.py
@@ -7,16 +7,12 @@
 from hipscat_import.index.arguments import IndexArguments
 
 
-def _validate_args(args):
+def run(args):
+    """Run index creation pipeline."""
     if not args:
         raise TypeError("args is required and should be type IndexArguments")
     if not isinstance(args, IndexArguments):
         raise TypeError("args must be type IndexArguments")
-
-
-def run(args):
-    """Importer, where the client context may out-live the runner"""
-    _validate_args(args)
 
     rows_written = mr.create_index(args)
 
     # All done - write out the metadata
diff --git a/src/hipscat_import/margin_cache/__init__.py b/src/hipscat_import/margin_cache/__init__.py
index 940d08b5..93f63fe3 100644
--- a/src/hipscat_import/margin_cache/__init__.py
+++ b/src/hipscat_import/margin_cache/__init__.py
@@ -1,4 +1,3 @@
 """All modules for generating margin caches."""
-from .margin_cache import (generate_margin_cache,
-                           generate_margin_cache_with_client)
+from .margin_cache import generate_margin_cache
 from .margin_cache_arguments import MarginCacheArguments
diff --git a/src/hipscat_import/margin_cache/margin_cache.py b/src/hipscat_import/margin_cache/margin_cache.py
index 59734bba..504ea060 100644
--- a/src/hipscat_import/margin_cache/margin_cache.py
+++ b/src/hipscat_import/margin_cache/margin_cache.py
@@ -1,5 +1,5 @@
 import pandas as pd
-from dask.distributed import Client, as_completed
+from dask.distributed import as_completed
 from hipscat import pixel_math
 from hipscat.io import file_io, paths
 from tqdm import tqdm
@@ -8,6 +8,7 @@
 
 # pylint: disable=too-many-locals,too-many-arguments
 
+
 def _find_partition_margin_pixel_pairs(stats, margin_order):
     """Creates a DataFrame filled with many-to-many connections between
     the catalog partition pixels and the margin pixels at `margin_order`.
@@ -31,30 +32,28 @@ def _find_partition_margin_pixel_pairs(stats, margin_order):
 
     margin_pairs_df = pd.DataFrame(
         zip(norders, part_pix, margin_pix),
-        columns=["partition_order", "partition_pixel", "margin_pixel"]
+        columns=["partition_order", "partition_pixel", "margin_pixel"],
     )
     return margin_pairs_df
 
+
 def _create_margin_directory(stats, output_path):
     """Creates directories for all the catalog partitions."""
     for healpixel in stats:
         order = healpixel.order
         pix = healpixel.pixel
 
-        destination_dir = paths.pixel_directory(
-            output_path, order, pix
-        )
+        destination_dir = paths.pixel_directory(output_path, order, pix)
         file_io.make_directory(destination_dir, exist_ok=True)
 
+
 def _map_to_margin_shards(client, args, partition_pixels, margin_pairs):
     """Create all the jobs for mapping partition files into the margin cache."""
     futures = []
     mp_future = client.scatter(margin_pairs, broadcast=True)
     for pix in partition_pixels:
         partition_file = paths.pixel_catalog_file(
-            args.input_catalog_path,
-            pix.order,
-            pix.pixel
+            args.input_catalog_path, pix.order, pix.pixel
         )
         futures.append(
             client.submit(
@@ -76,50 +75,27 @@
     ):
         ...
 
-def generate_margin_cache(args):
+
+def generate_margin_cache(args, client):
     """Generate a margin cache for a given input catalog.
     The input catalog must be in hipscat format.
 
-    This method will handle the creation of the `dask.distributed` client
-    based on the `dask_tmp`, `dask_n_workers`, and `dask_threads_per_worker`
-    values of the `MarginCacheArguments` object.
-
     Args:
         args (MarginCacheArguments): A valid `MarginCacheArguments` object.
- """ - # pylint: disable=duplicate-code - with Client( - local_directory=args.dask_tmp, - n_workers=args.dask_n_workers, - threads_per_worker=args.dask_threads_per_worker, - ) as client: # pragma: no cover - generate_margin_cache_with_client( - client, - args - ) - # pylint: enable=duplicate-code - -def generate_margin_cache_with_client(client, args): - """Generate a margin cache for a given input catalog. - The input catalog must be in hipscat format. - Args: client (dask.distributed.Client): A dask distributed client object. - args (MarginCacheArguments): A valid `MarginCacheArguments` object. """ # determine which order to generate margin pixels for partition_stats = args.catalog.partition_info.get_healpix_pixels() margin_pairs = _find_partition_margin_pixel_pairs( - partition_stats, - args.margin_order + partition_stats, args.margin_order ) # arcsec to degree conversion # TODO: remove this once hipscat uses arcsec for calculation - args.margin_threshold = args.margin_threshold / 3600. + args.margin_threshold = args.margin_threshold / 3600.0 - _create_margin_directory( - partition_stats, args.catalog_path - ) + _create_margin_directory(partition_stats, args.catalog_path) _map_to_margin_shards( client=client, diff --git a/src/hipscat_import/margin_cache/margin_cache_map_reduce.py b/src/hipscat_import/margin_cache/margin_cache_map_reduce.py index ee83837d..8ff13c41 100644 --- a/src/hipscat_import/margin_cache/margin_cache_map_reduce.py +++ b/src/hipscat_import/margin_cache/margin_cache_map_reduce.py @@ -6,6 +6,7 @@ # pylint: disable=too-many-locals,too-many-arguments + def map_pixel_shards( partition_file, margin_pairs, @@ -13,7 +14,7 @@ def map_pixel_shards( output_path, margin_order, ra_column, - dec_column + dec_column, ): """Creates margin cache shards from a source partition file.""" data = pd.read_parquet(partition_file) @@ -38,7 +39,10 @@ def map_pixel_shards( dec_column=dec_column, ) -def _to_pixel_shard(data, margin_threshold, output_path, margin_order, ra_column, dec_column): + +def _to_pixel_shard( + data, margin_threshold, output_path, margin_order, ra_column, dec_column +): """Do boundary checking for the cached partition and then output remaining data.""" order, pix = data["partition_order"].iloc[0], data["partition_pixel"].iloc[0] source_order, source_pix = data["Norder"].iloc[0], data["Npix"].iloc[0] @@ -60,9 +64,7 @@ def _to_pixel_shard(data, margin_threshold, output_path, margin_order, ra_column ) else: data["margin_check"] = pixel_math.check_margin_bounds( - data[ra_column].values, - data[dec_column].values, - bounding_polygons + data[ra_column].values, data[dec_column].values, bounding_polygons ) # pylint: disable-next=singleton-comparison @@ -74,28 +76,27 @@ def _to_pixel_shard(data, margin_threshold, output_path, margin_order, ra_column # generate a file name for our margin shard partition_file = paths.pixel_catalog_file(output_path, order, pix) partition_dir = f"{partition_file[:-8]}/" - shard_dir = paths.pixel_directory( - partition_dir, source_order, source_pix - ) + shard_dir = paths.pixel_directory(partition_dir, source_order, source_pix) file_io.make_directory(shard_dir, exist_ok=True) - shard_path = paths.pixel_catalog_file( - partition_dir, source_order, source_pix - ) + shard_path = paths.pixel_catalog_file(partition_dir, source_order, source_pix) - final_df = margin_data.drop(columns=[ - "partition_order", - "partition_pixel", - "margin_check", - "margin_pixel" - ]) + final_df = margin_data.drop( + columns=[ + "partition_order", + "partition_pixel", + 
"margin_check", + "margin_pixel", + ] + ) if is_polar: final_df = final_df.drop(columns=["is_trunc"]) final_df.to_parquet(shard_path) + def _margin_filter_polar( data, order, @@ -104,13 +105,11 @@ def _margin_filter_polar( margin_threshold, ra_column, dec_column, - bounding_polygons + bounding_polygons, ): """Filter out margin data around the poles.""" trunc_pix = pixel_math.get_truncated_margin_pixels( - order=order, - pix=pix, - margin_order=margin_order + order=order, pix=pix, margin_order=margin_order ) data["is_trunc"] = np.isin(data["margin_pixel"], trunc_pix) @@ -124,14 +123,12 @@ def _margin_filter_polar( order, pix, margin_order, - margin_threshold + margin_threshold, ) data.loc[data["is_trunc"] == True, "margin_check"] = trunc_check no_trunc_check = pixel_math.check_margin_bounds( - other_data[ra_column].values, - other_data[dec_column].values, - bounding_polygons + other_data[ra_column].values, other_data[dec_column].values, bounding_polygons ) data.loc[data["is_trunc"] == False, "margin_check"] = no_trunc_check # pylint: enable=singleton-comparison diff --git a/src/hipscat_import/pipeline.py b/src/hipscat_import/pipeline.py new file mode 100644 index 00000000..eb0f47c9 --- /dev/null +++ b/src/hipscat_import/pipeline.py @@ -0,0 +1,83 @@ +"""Flow control and pipeline entry points.""" +import smtplib +from email.message import EmailMessage + +from dask.distributed import Client + +import hipscat_import.association.run_association as association_runner +import hipscat_import.catalog.run_import as catalog_runner +import hipscat_import.index.run_index as index_runner +import hipscat_import.margin_cache.margin_cache as margin_runner +from hipscat_import.association.arguments import AssociationArguments +from hipscat_import.catalog.arguments import ImportArguments +from hipscat_import.index.arguments import IndexArguments +from hipscat_import.margin_cache.margin_cache_arguments import MarginCacheArguments +from hipscat_import.runtime_arguments import RuntimeArguments + +# pragma: no cover + + +def pipeline(args: RuntimeArguments): + """Pipeline that creates its own client from the provided runtime arguments""" + with Client( + local_directory=args.dask_tmp, + n_workers=args.dask_n_workers, + threads_per_worker=args.dask_threads_per_worker, + ) as client: + pipeline_with_client(args, client) + + +def pipeline_with_client(args: RuntimeArguments, client: Client): + """Pipeline that is run using an existing client. + + This can be useful in tests, or when a dask client requires some more complex + configuraion. + """ + try: + if not args: + raise ValueError("args is required and should be subclass of RuntimeArguments") + + if isinstance(args, ImportArguments): + catalog_runner.run(args, client) + elif isinstance(args, AssociationArguments): + association_runner.run(args) + elif isinstance(args, IndexArguments): + index_runner.run(args) + elif isinstance(args, MarginCacheArguments): + margin_runner.generate_margin_cache(args, client) + else: + raise ValueError("unknown args type") + except Exception as exception: # pylint: disable=broad-exception-caught + _send_failure_email(args, exception) + else: + _send_success_email(args) + + +def _send_failure_email(args: RuntimeArguments, exception: Exception): + if not args.completion_email_address: + raise exception + message = EmailMessage() + message["Subject"] = "hipscat-import failure." 
+ message["To"] = args.completion_email_address + message.set_content(f"failed with message:\n{exception}") + + _send_email(message) + + +def _send_success_email(args): + if not args.completion_email_address: + return + message = EmailMessage() + message["Subject"] = "hipscat-import success." + message["To"] = args.completion_email_address + message.set_content(f"output_catalog_name: {args.output_catalog_name}") + + _send_email(message) + + +def _send_email(message: EmailMessage): + message["From"] = "updates@lsdb.io" + + # Send the message via our own SMTP server. + with smtplib.SMTP("localhost") as server: + server.send_message(message) diff --git a/src/hipscat_import/runtime_arguments.py b/src/hipscat_import/runtime_arguments.py index c7253450..a2bce52a 100644 --- a/src/hipscat_import/runtime_arguments.py +++ b/src/hipscat_import/runtime_arguments.py @@ -36,6 +36,10 @@ class RuntimeArguments: dask_threads_per_worker: int = 1 """number of threads per dask worker""" + completion_email_address: str = "" + """if provided, send an email to the indicated email address once the + import pipeline has complete.""" + catalog_path = "" """constructed output path for the catalog that will be something like /""" diff --git a/tests/hipscat_import/association/test_association_map_reduce.py b/tests/hipscat_import/association/test_association_map_reduce.py index 88de77bb..8d310869 100644 --- a/tests/hipscat_import/association/test_association_map_reduce.py +++ b/tests/hipscat_import/association/test_association_map_reduce.py @@ -33,7 +33,7 @@ def test_map_association( ) subset_catalog_path = args.catalog_path - runner.run_with_client(args, dask_client) + runner.run(args, dask_client) with open( os.path.join(subset_catalog_path, "catalog_info.json"), "r", encoding="utf-8" diff --git a/tests/hipscat_import/association/test_run_association.py b/tests/hipscat_import/association/test_run_association.py index c09a31f7..628b6c6b 100644 --- a/tests/hipscat_import/association/test_run_association.py +++ b/tests/hipscat_import/association/test_run_association.py @@ -6,8 +6,7 @@ import numpy.testing as npt import pandas as pd import pytest -from hipscat.catalog.association_catalog.association_catalog import \ - AssociationCatalog +from hipscat.catalog.association_catalog.association_catalog import AssociationCatalog import hipscat_import.association.run_association as runner from hipscat_import.association.arguments import AssociationArguments diff --git a/tests/hipscat_import/catalog/test_argument_validation.py b/tests/hipscat_import/catalog/test_argument_validation.py index f7e377f4..9954415b 100644 --- a/tests/hipscat_import/catalog/test_argument_validation.py +++ b/tests/hipscat_import/catalog/test_argument_validation.py @@ -3,8 +3,7 @@ import pytest -from hipscat_import.catalog.arguments import (ImportArguments, - check_healpix_order_range) +from hipscat_import.catalog.arguments import ImportArguments, check_healpix_order_range # pylint: disable=protected-access diff --git a/tests/hipscat_import/catalog/test_map_reduce.py b/tests/hipscat_import/catalog/test_map_reduce.py index 7589b2f0..356ac78c 100644 --- a/tests/hipscat_import/catalog/test_map_reduce.py +++ b/tests/hipscat_import/catalog/test_map_reduce.py @@ -194,9 +194,7 @@ def test_reduce_order0(parquet_shards_dir, assert_parquet_file_ids, tmp_path): assert_parquet_file_ids(output_file, "id", expected_ids) -def test_reduce_hipscat_index( - parquet_shards_dir, assert_parquet_file_ids, tmp_path -): +def test_reduce_hipscat_index(parquet_shards_dir, 
diff --git a/tests/hipscat_import/association/test_association_map_reduce.py b/tests/hipscat_import/association/test_association_map_reduce.py
index 88de77bb..8d310869 100644
--- a/tests/hipscat_import/association/test_association_map_reduce.py
+++ b/tests/hipscat_import/association/test_association_map_reduce.py
@@ -33,7 +33,7 @@ def test_map_association(
     )
     subset_catalog_path = args.catalog_path
 
-    runner.run_with_client(args, dask_client)
+    runner.run(args, dask_client)
 
     with open(
         os.path.join(subset_catalog_path, "catalog_info.json"), "r", encoding="utf-8"
diff --git a/tests/hipscat_import/association/test_run_association.py b/tests/hipscat_import/association/test_run_association.py
index c09a31f7..628b6c6b 100644
--- a/tests/hipscat_import/association/test_run_association.py
+++ b/tests/hipscat_import/association/test_run_association.py
@@ -6,8 +6,7 @@
 import numpy.testing as npt
 import pandas as pd
 import pytest
-from hipscat.catalog.association_catalog.association_catalog import \
-    AssociationCatalog
+from hipscat.catalog.association_catalog.association_catalog import AssociationCatalog
 
 import hipscat_import.association.run_association as runner
 from hipscat_import.association.arguments import AssociationArguments
diff --git a/tests/hipscat_import/catalog/test_argument_validation.py b/tests/hipscat_import/catalog/test_argument_validation.py
index f7e377f4..9954415b 100644
--- a/tests/hipscat_import/catalog/test_argument_validation.py
+++ b/tests/hipscat_import/catalog/test_argument_validation.py
@@ -3,8 +3,7 @@
 
 import pytest
 
-from hipscat_import.catalog.arguments import (ImportArguments,
-                                              check_healpix_order_range)
+from hipscat_import.catalog.arguments import ImportArguments, check_healpix_order_range
 
 # pylint: disable=protected-access
diff --git a/tests/hipscat_import/catalog/test_map_reduce.py b/tests/hipscat_import/catalog/test_map_reduce.py
index 7589b2f0..356ac78c 100644
--- a/tests/hipscat_import/catalog/test_map_reduce.py
+++ b/tests/hipscat_import/catalog/test_map_reduce.py
@@ -194,9 +194,7 @@ def test_reduce_order0(parquet_shards_dir, assert_parquet_file_ids, tmp_path):
     assert_parquet_file_ids(output_file, "id", expected_ids)
 
 
-def test_reduce_hipscat_index(
-    parquet_shards_dir, assert_parquet_file_ids, tmp_path
-):
+def test_reduce_hipscat_index(parquet_shards_dir, assert_parquet_file_ids, tmp_path):
     """Test reducing with or without a _hipscat_index field"""
     mr.reduce_pixel_shards(
         cache_path=parquet_shards_dir,
diff --git a/tests/hipscat_import/catalog/test_run_import.py b/tests/hipscat_import/catalog/test_run_import.py
index 0a30b558..50b6efff 100644
--- a/tests/hipscat_import/catalog/test_run_import.py
+++ b/tests/hipscat_import/catalog/test_run_import.py
@@ -13,15 +13,15 @@
 
 def test_empty_args():
     """Runner should fail with empty arguments"""
-    with pytest.raises(ValueError):
-        runner.run(None)
+    with pytest.raises(ValueError, match="args is required"):
+        runner.run(None, None)
 
 
 def test_bad_args():
     """Runner should fail with mis-typed arguments"""
     args = {"output_catalog_name": "bad_arg_type"}
-    with pytest.raises(ValueError):
-        runner.run(args)
+    with pytest.raises(ValueError, match="ImportArguments"):
+        runner.run(args, None)
 
 
 @pytest.mark.dask
@@ -88,7 +88,7 @@ def test_resume_dask_runner(
         progress_bar=False,
     )
 
-    runner.run_with_client(args, dask_client)
+    runner.run(args, dask_client)
 
     # Check that the catalog metadata file exists
     catalog = Catalog.read_from_hipscat(args.catalog_path)
@@ -131,7 +131,7 @@ def test_resume_dask_runner(
         progress_bar=False,
     )
 
-    runner.run_with_client(args, dask_client)
+    runner.run(args, dask_client)
 
     catalog = Catalog.read_from_hipscat(args.catalog_path)
     assert catalog.on_disk
@@ -161,7 +161,7 @@ def test_dask_runner(
         progress_bar=False,
     )
 
-    runner.run_with_client(args, dask_client)
+    runner.run(args, dask_client)
 
     # Check that the catalog metadata file exists
     catalog = Catalog.read_from_hipscat(args.catalog_path)
@@ -195,7 +195,7 @@ def test_dask_runner_stats_only(dask_client, small_sky_parts_dir, tmp_path):
         debug_stats_only=True,
     )
 
-    runner.run_with_client(args, dask_client)
+    runner.run(args, dask_client)
 
     metadata_filename = os.path.join(args.catalog_path, "catalog_info.json")
     assert os.path.exists(metadata_filename)
diff --git a/tests/hipscat_import/catalog/test_run_round_trip.py b/tests/hipscat_import/catalog/test_run_round_trip.py
index 8b050f01..d8c12fed 100644
--- a/tests/hipscat_import/catalog/test_run_round_trip.py
+++ b/tests/hipscat_import/catalog/test_run_round_trip.py
@@ -44,7 +44,7 @@ def test_import_source_table(
         progress_bar=False,
     )
 
-    runner.run_with_client(args, dask_client)
+    runner.run(args, dask_client)
 
     # Check that the catalog metadata file exists
     catalog = Catalog.read_from_hipscat(args.catalog_path)
@@ -84,7 +84,7 @@ def test_import_mixed_schema_csv(
         use_schema_file=mixed_schema_csv_parquet,
     )
 
-    runner.run_with_client(args, dask_client)
+    runner.run(args, dask_client)
 
     # Check that the catalog parquet file exists
     output_file = os.path.join(
@@ -140,7 +140,7 @@ def test_import_preserve_index(
         progress_bar=False,
     )
 
-    runner.run_with_client(args, dask_client)
+    runner.run(args, dask_client)
 
     # Check that the catalog parquet file exists
     output_file = os.path.join(
@@ -168,7 +168,7 @@ def test_import_preserve_index(
         progress_bar=False,
     )
 
-    runner.run_with_client(args, dask_client)
+    runner.run(args, dask_client)
 
     # Check that the catalog parquet file exists
     output_file = os.path.join(
@@ -235,7 +235,7 @@ def test_import_multiindex(
         progress_bar=False,
     )
 
-    runner.run_with_client(args, dask_client)
+    runner.run(args, dask_client)
 
     # Check that the catalog parquet file exists
     output_file = os.path.join(
@@ -263,7 +263,7 @@ def test_import_multiindex(
         progress_bar=False,
     )
 
-    runner.run_with_client(args, dask_client)
+    runner.run(args, dask_client)
 
     # Check that the catalog parquet file exists
     output_file = os.path.join(
@@ -299,14 +299,16 @@ def test_import_constant_healpix_order(
         progress_bar=False,
     )
 
-    runner.run_with_client(args, dask_client)
+    runner.run(args, dask_client)
 
     # Check that the catalog metadata file exists
     catalog = Catalog.read_from_hipscat(args.catalog_path)
     assert catalog.on_disk
     assert catalog.catalog_path == args.catalog_path
     # Check that the partition info file exists - all pixels at order 2!
-    assert all(pixel.order == 2 for pixel in catalog.partition_info.get_healpix_pixels())
+    assert all(
+        pixel.order == 2 for pixel in catalog.partition_info.get_healpix_pixels()
+    )
 
     # Pick a parquet file and make sure it contains as many rows as we expect
     output_file = os.path.join(
diff --git a/tests/hipscat_import/margin_cache/test_margin_cache.py b/tests/hipscat_import/margin_cache/test_margin_cache.py
index 06672e1a..333acb02 100644
--- a/tests/hipscat_import/margin_cache/test_margin_cache.py
+++ b/tests/hipscat_import/margin_cache/test_margin_cache.py
@@ -7,6 +7,8 @@
 import hipscat_import.margin_cache.margin_cache as mc
 from hipscat_import.margin_cache import MarginCacheArguments
 
+# pylint: disable=protected-access
+
 
 @pytest.mark.dask(timeout=20)
 def test_margin_cache_gen(small_sky_source_catalog, tmp_path, dask_client):
@@ -20,10 +22,11 @@ def test_margin_cache_gen(small_sky_source_catalog, tmp_path, dask_client):
     assert args.catalog.catalog_info.ra_column == "source_ra"
 
-    mc.generate_margin_cache_with_client(dask_client, args)
+    mc.generate_margin_cache(args, dask_client)
 
     # TODO: add more verification of output to this test once the
     # reduce phase is implemented.
 
+
 def test_partition_margin_pixel_pairs(small_sky_source_catalog, tmp_path):
     args = MarginCacheArguments(
         margin_threshold=5.0,
@@ -33,8 +36,7 @@ def test_partition_margin_pixel_pairs(small_sky_source_catalog, tmp_path):
     )
 
     margin_pairs = mc._find_partition_margin_pixel_pairs(
-        args.catalog.partition_info.get_healpix_pixels(),
-        args.margin_order
+        args.catalog.partition_info.get_healpix_pixels(), args.margin_order
     )
 
     expected = np.array([725, 733, 757, 765, 727, 735, 759, 767, 469, 192])
@@ -42,6 +44,7 @@ def test_partition_margin_pixel_pairs(small_sky_source_catalog, tmp_path):
     npt.assert_array_equal(margin_pairs.iloc[:10]["margin_pixel"], expected)
     assert len(margin_pairs) == 196
 
+
 def test_create_margin_directory(small_sky_source_catalog, tmp_path):
     args = MarginCacheArguments(
         margin_threshold=5.0,
@@ -52,10 +55,8 @@ def test_create_margin_directory(small_sky_source_catalog, tmp_path):
 
     mc._create_margin_directory(
         stats=args.catalog.partition_info.get_healpix_pixels(),
-        output_path=args.catalog_path
+        output_path=args.catalog_path,
     )
 
-    output = file_io.append_paths_to_pointer(
-        args.catalog_path, "Norder=0", "Dir=0"
-    )
+    output = file_io.append_paths_to_pointer(args.catalog_path, "Norder=0", "Dir=0")
     assert file_io.does_file_or_directory_exist(output)
diff --git a/tests/hipscat_import/margin_cache/test_margin_cache_map_reduce.py b/tests/hipscat_import/margin_cache/test_margin_cache_map_reduce.py
index b58b7c29..403302ed 100644
--- a/tests/hipscat_import/margin_cache/test_margin_cache_map_reduce.py
+++ b/tests/hipscat_import/margin_cache/test_margin_cache_map_reduce.py
@@ -6,19 +6,17 @@
 
 from hipscat_import.margin_cache import margin_cache_map_reduce
 
-keep_cols = [
-    "weird_ra",
-    "weird_dec"
-]
+keep_cols = ["weird_ra", "weird_dec"]
 
 drop_cols = [
-    "partition_order",
-    "partition_pixel",
-    "margin_check",
+    "partition_order",
+    "partition_pixel",
+    "margin_check",
     "margin_pixel",
-    "is_trunc"
+    "is_trunc",
 ]
 
+
 def validate_result_dataframe(df_path, expected_len):
     res_df = pd.read_parquet(df_path)
@@ -32,10 +30,11 @@
     for col in drop_cols:
         assert col not in cols
 
+
 @pytest.mark.timeout(5)
 def test_to_pixel_shard_equator(tmp_path):
-    ras = np.arange(0.,360.)
-    dec = np.full(360, 0.)
+    ras = np.arange(0.0, 360.0)
+    dec = np.full(360, 0.0)
     ppix = np.full(360, 21)
     porder = np.full(360, 1)
     norder = np.full(360, 1)
@@ -44,13 +43,13 @@ def test_to_pixel_shard_equator(tmp_path):
     test_df = pd.DataFrame(
         data=zip(ras, dec, ppix, porder, norder, npix),
         columns=[
-            "weird_ra", 
+            "weird_ra",
             "weird_dec",
             "partition_pixel",
             "partition_order",
             "Norder",
-            "Npix"
-        ]
+            "Npix",
+        ],
     )
 
     test_df["margin_pixel"] = hp.ang2pix(
@@ -58,7 +57,7 @@ def test_to_pixel_shard_equator(tmp_path):
         test_df["weird_ra"].values,
         test_df["weird_dec"].values,
         lonlat=True,
-        nest=True
+        nest=True,
     )
 
     margin_cache_map_reduce._to_pixel_shard(
@@ -67,21 +66,21 @@ def test_to_pixel_shard_equator(tmp_path):
         output_path=tmp_path,
         margin_order=3,
         ra_column="weird_ra",
-        dec_column="weird_dec"
+        dec_column="weird_dec",
     )
 
     path = file_io.append_paths_to_pointer(
-        tmp_path,
-        "Norder=1/Dir=0/Npix=21/Norder=1/Dir=0/Npix=0.parquet"
+        tmp_path, "Norder=1/Dir=0/Npix=21/Norder=1/Dir=0/Npix=0.parquet"
     )
 
     assert file_io.does_file_or_directory_exist(path)
 
     validate_result_dataframe(path, 46)
 
+
 @pytest.mark.timeout(5)
 def test_to_pixel_shard_polar(tmp_path):
-    ras = np.arange(0.,360.)
+    ras = np.arange(0.0, 360.0)
     dec = np.full(360, 89.9)
     ppix = np.full(360, 15)
     porder = np.full(360, 2)
@@ -91,13 +90,13 @@ def test_to_pixel_shard_polar(tmp_path):
     test_df = pd.DataFrame(
         data=zip(ras, dec, ppix, porder, norder, npix),
         columns=[
-            "weird_ra", 
+            "weird_ra",
             "weird_dec",
             "partition_pixel",
             "partition_order",
             "Norder",
-            "Npix"
-        ]
+            "Npix",
+        ],
     )
 
     test_df["margin_pixel"] = hp.ang2pix(
@@ -105,7 +104,7 @@ def test_to_pixel_shard_polar(tmp_path):
         test_df["weird_ra"].values,
         test_df["weird_dec"].values,
         lonlat=True,
-        nest=True
+        nest=True,
     )
 
     margin_cache_map_reduce._to_pixel_shard(
@@ -114,12 +113,11 @@ def test_to_pixel_shard_polar(tmp_path):
         output_path=tmp_path,
         margin_order=3,
         ra_column="weird_ra",
-        dec_column="weird_dec"
+        dec_column="weird_dec",
     )
 
     path = file_io.append_paths_to_pointer(
-        tmp_path,
-        "Norder=2/Dir=0/Npix=15/Norder=2/Dir=0/Npix=0.parquet"
+        tmp_path, "Norder=2/Dir=0/Npix=15/Norder=2/Dir=0/Npix=0.parquet"
     )
 
     assert file_io.does_file_or_directory_exist(path)