Macauff pipeline boilerplate (#152)
* checkpoint

* checkpoint

* testing for MacauffArguments

* create boilerplate for macauff runner + tests

* remove commented out test

* add __future__.annotations

* linter problems

* add more tests for missing coverage

* refactor MacauffArguments required parameter tests

* address more comments from pr #152

* add dask pytest mark + black formatting
maxwest-uw committed Oct 30, 2023
1 parent f8ede50 commit 46ee668
Showing 7 changed files with 421 additions and 0 deletions.
128 changes: 128 additions & 0 deletions src/hipscat_import/cross_match/macauff_arguments.py
@@ -0,0 +1,128 @@
from __future__ import annotations

from dataclasses import dataclass, field
from os import path
from typing import List

from hipscat.io import FilePointer, file_io
from hipscat.io.validation import is_valid_catalog

from hipscat_import.runtime_arguments import RuntimeArguments

# pylint: disable=too-many-instance-attributes
# pylint: disable=unsupported-binary-operation


@dataclass
class MacauffArguments(RuntimeArguments):
"""Data class for holding cross-match association arguments"""

## Input - Cross-match data
input_path: FilePointer | None = None
"""path to search for the input data"""
input_format: str = ""
"""specifier of the input data format. this will be used to find an appropriate
InputReader type, and may be used to find input files, via a match like
``<input_path>/*<input_format>`` """
input_file_list: List[FilePointer] = field(default_factory=list)
"""can be used instead of `input_format` to import only specified files"""
input_paths: List[FilePointer] = field(default_factory=list)
"""resolved list of all files that will be used in the importer"""
add_hipscat_index: bool = True
"""add the hipscat spatial index field alongside the data"""

## Input - Left catalog
left_catalog_dir: str = ""
left_id_column: str = ""
left_ra_column: str = ""
left_dec_column: str = ""

## Input - Right catalog
right_catalog_dir: str = ""
right_id_column: str = ""
right_ra_column: str = ""
right_dec_column: str = ""

## `macauff` specific attributes
metadata_file_path: str = ""
match_probability_columns: List[str] = field(default_factory=list)
column_names: List[str] = field(default_factory=list)

def __post_init__(self):
self._check_arguments()

def _check_arguments(self):
super()._check_arguments()

if not self.input_path and not self.input_file_list:
raise ValueError("input_path nor input_file_list not provided")
        if not self.input_format:
            raise ValueError("input_format is required")

        if not self.left_catalog_dir:
            raise ValueError("left_catalog_dir is required")
        if not self.left_id_column:
            raise ValueError("left_id_column is required")
        if not self.left_ra_column:
            raise ValueError("left_ra_column is required")
        if not self.left_dec_column:
            raise ValueError("left_dec_column is required")
        if not is_valid_catalog(self.left_catalog_dir):
            raise ValueError("left_catalog_dir not a valid catalog")

        if not self.right_catalog_dir:
            raise ValueError("right_catalog_dir is required")
        if not self.right_id_column:
            raise ValueError("right_id_column is required")
        if not self.right_ra_column:
            raise ValueError("right_ra_column is required")
        if not self.right_dec_column:
            raise ValueError("right_dec_column is required")
        if not is_valid_catalog(self.right_catalog_dir):
            raise ValueError("right_catalog_dir not a valid catalog")

        if not self.metadata_file_path:
            raise ValueError("metadata_file_path required for macauff crossmatch")
        if not path.isfile(self.metadata_file_path):
            raise ValueError("Macauff column metadata file must point to valid file path.")

        # Basic checks complete - make more checks and create directories where necessary
        if self.input_path:
            if not file_io.does_file_or_directory_exist(self.input_path):
                raise FileNotFoundError("input_path not found on local storage")
            self.input_paths = file_io.find_files_matching_path(self.input_path, f"*{self.input_format}")
        elif self.input_file_list:
            self.input_paths = self.input_file_list
        if len(self.input_paths) == 0:
            raise FileNotFoundError("No input files found")

        self.column_names = self.get_column_names()

    def get_column_names(self):
        """Grab the macauff column names."""
        # TODO: Actually read in the metadata file once we get the example file from Tom.

        return [
            "Gaia_designation",
            "Gaia_RA",
            "Gaia_Dec",
            "BP",
            "G",
            "RP",
            "CatWISE_Name",
            "CatWISE_RA",
            "CatWISE_Dec",
            "W1",
            "W2",
            "match_p",
            "Separation",
            "eta",
            "xi",
            "Gaia_avg_cont",
            "CatWISE_avg_cont",
            "Gaia_cont_f1",
            "Gaia_cont_f10",
            "CatWISE_cont_f1",
            "CatWISE_cont_f10",
            "CatWISE_fit_sig",
        ]
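
For orientation, a minimal usage sketch of the new arguments class (not part of this commit), mirroring the first test further below; every path here is a placeholder that you would swap for real hipscat catalogs, input data, and the macauff column-metadata file:

from hipscat_import.cross_match.macauff_arguments import MacauffArguments

# Placeholder paths -- substitute your own catalogs, input files, and metadata YAML.
args = MacauffArguments(
    output_path="/tmp/macauff_output",
    output_catalog_name="object_to_source",
    tmp_dir="/tmp/macauff_tmp",
    left_catalog_dir="catalogs/gaia",
    left_ra_column="ra",
    left_dec_column="dec",
    left_id_column="id",
    right_catalog_dir="catalogs/catwise",
    right_ra_column="source_ra",
    right_dec_column="source_dec",
    right_id_column="source_id",
    input_path="data/macauff_results",
    input_format="csv",
    metadata_file_path="data/macauff_metadata.yaml",
)
# __post_init__ calls _check_arguments, so missing parameters, invalid catalogs,
# or an empty input-file match raise ValueError / FileNotFoundError at construction time.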
13 changes: 13 additions & 0 deletions src/hipscat_import/cross_match/run_macauff_import.py
@@ -0,0 +1,13 @@
from hipscat_import.cross_match.macauff_arguments import MacauffArguments

# pylint: disable=unused-argument


def run(args, client):
"""run macauff cross-match import pipeline"""
if not args:
raise TypeError("args is required and should be type MacauffArguments")
if not isinstance(args, MacauffArguments):
raise TypeError("args must be type MacauffArguments")

raise NotImplementedError("macauff pipeline not implemented yet.")
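
The runner is still a stub: it validates that `args` is a MacauffArguments instance and then raises NotImplementedError. A hedged call sketch (the local Dask client settings are illustrative, not prescribed by this commit), reusing `args` from the example above:

from dask.distributed import Client

import hipscat_import.cross_match.run_macauff_import as macauff_runner

client = Client(n_workers=1, threads_per_worker=1)  # illustrative local cluster
try:
    macauff_runner.run(args, client)  # `args` is the MacauffArguments instance built earlier
except NotImplementedError:
    pass  # expected until the cross-match import stage is implemented
client.close()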
4 changes: 4 additions & 0 deletions src/hipscat_import/pipeline.py
Expand Up @@ -5,11 +5,13 @@
from dask.distributed import Client

import hipscat_import.catalog.run_import as catalog_runner
import hipscat_import.cross_match.run_macauff_import as macauff_runner
import hipscat_import.index.run_index as index_runner
import hipscat_import.margin_cache.margin_cache as margin_runner
import hipscat_import.soap.run_soap as soap_runner
import hipscat_import.verification.run_verification as verification_runner
from hipscat_import.catalog.arguments import ImportArguments
from hipscat_import.cross_match.macauff_arguments import MacauffArguments
from hipscat_import.index.arguments import IndexArguments
from hipscat_import.margin_cache.margin_cache_arguments import MarginCacheArguments
from hipscat_import.runtime_arguments import RuntimeArguments
@@ -49,6 +51,8 @@ def pipeline_with_client(args: RuntimeArguments, client: Client):
            soap_runner.run(args, client)
        elif isinstance(args, VerificationArguments):
            verification_runner.run(args)
        elif isinstance(args, MacauffArguments):
            macauff_runner.run(args, client)
        else:
            raise ValueError("unknown args type")
    except Exception as exception:  # pylint: disable=broad-exception-caught
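
In normal use the runner is not called directly; the pipeline entry point dispatches on the argument type via the new `elif` branch above. A sketch of that flow (not part of this commit), reusing the `args` placeholder and the same illustrative local Dask client:

from dask.distributed import Client

from hipscat_import.pipeline import pipeline_with_client

with Client(n_workers=1, threads_per_worker=1) as client:  # illustrative local cluster
    # isinstance(args, MacauffArguments) routes to macauff_runner.run(args, client);
    # until that runner is implemented, the stub's NotImplementedError is intercepted
    # by the broad exception handler shown above.
    pipeline_with_client(args, client)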
5 changes: 5 additions & 0 deletions tests/hipscat_import/conftest.py
Expand Up @@ -124,6 +124,11 @@ def formats_multiindex(test_data_dir):
return os.path.join(test_data_dir, "test_formats", "multiindex.parquet")


@pytest.fixture
def formats_yaml(test_data_dir):
return os.path.join(test_data_dir, "test_formats", "macauff_metadata.yaml")


@pytest.fixture
def small_sky_parts_dir(test_data_dir):
return os.path.join(test_data_dir, "small_sky_parts")
219 changes: 219 additions & 0 deletions tests/hipscat_import/cross_match/test_macauff_arguments.py
@@ -0,0 +1,219 @@
"""Tests of macauff arguments"""


from os import path

import pytest

from hipscat_import.cross_match.macauff_arguments import MacauffArguments

# pylint: disable=duplicate-code


def test_macauff_arguments(
    small_sky_object_catalog, small_sky_source_catalog, small_sky_dir, formats_yaml, tmp_path
):
    """Test that we can create a MacauffArguments instance with two valid catalogs."""
    args = MacauffArguments(
        output_path=tmp_path,
        output_catalog_name="object_to_source",
        tmp_dir=tmp_path,
        left_catalog_dir=small_sky_object_catalog,
        left_ra_column="ra",
        left_dec_column="dec",
        left_id_column="id",
        right_catalog_dir=small_sky_source_catalog,
        right_ra_column="source_ra",
        right_dec_column="source_dec",
        right_id_column="source_id",
        input_path=small_sky_dir,
        input_format="csv",
        metadata_file_path=formats_yaml,
    )

    assert len(args.input_paths) > 0


def test_empty_required(
    small_sky_object_catalog, small_sky_source_catalog, small_sky_dir, formats_yaml, tmp_path
):
    """All non-runtime arguments are required."""

    ## List of required args:
    ## - match expression that should be found when missing
    ## - default value
    required_args = [
        ["output_path", tmp_path],
        ["output_catalog_name", "object_to_source"],
        ["left_catalog_dir", small_sky_object_catalog],
        ["left_ra_column", "ra"],
        ["left_dec_column", "dec"],
        ["left_id_column", "id"],
        ["right_catalog_dir", small_sky_source_catalog],
        ["right_ra_column", "source_ra"],
        ["right_dec_column", "source_dec"],
        ["right_id_column", "source_id"],
        ["input_path", small_sky_dir],
        ["input_format", "csv"],
        ["metadata_file_path", formats_yaml],
    ]

    ## For each required argument, check that a ValueError is raised that matches the
    ## expected name of the missing param.
    for index, args in enumerate(required_args):
        test_args = [
            list_args[1] if list_index != index else None
            for list_index, list_args in enumerate(required_args)
        ]

        print(f"testing required arg #{index}")

        with pytest.raises(ValueError, match=args[0]):
            MacauffArguments(
                output_path=test_args[0],
                output_catalog_name=test_args[1],
                tmp_dir=tmp_path,
                left_catalog_dir=test_args[2],
                left_ra_column=test_args[3],
                left_dec_column=test_args[4],
                left_id_column=test_args[5],
                right_catalog_dir=test_args[6],
                right_ra_column=test_args[7],
                right_dec_column=test_args[8],
                right_id_column=test_args[9],
                input_path=test_args[10],
                input_format=test_args[11],
                metadata_file_path=test_args[12],
                overwrite=True,
            )


def test_macauff_arguments_file_list(
    small_sky_object_catalog, small_sky_source_catalog, small_sky_dir, formats_yaml, tmp_path
):
"""Test that we can create a MacauffArguments instance with two valid catalogs."""
    files = [path.join(small_sky_dir, "catalog.csv")]
    args = MacauffArguments(
        output_path=tmp_path,
        output_catalog_name="object_to_source",
        tmp_dir=tmp_path,
        left_catalog_dir=small_sky_object_catalog,
        left_ra_column="ra",
        left_dec_column="dec",
        left_id_column="id",
        right_catalog_dir=small_sky_source_catalog,
        right_ra_column="source_ra",
        right_dec_column="source_dec",
        right_id_column="source_id",
        input_file_list=files,
        input_format="csv",
        metadata_file_path=formats_yaml,
    )

    assert len(args.input_paths) > 0


def test_macauff_args_invalid_catalog(small_sky_source_catalog, small_sky_dir, formats_yaml, tmp_path):
    with pytest.raises(ValueError, match="left_catalog_dir"):
        MacauffArguments(
            output_path=tmp_path,
            output_catalog_name="object_to_source",
            tmp_dir=tmp_path,
            left_catalog_dir=small_sky_dir,  # valid path, but not a catalog
            left_ra_column="ra",
            left_dec_column="dec",
            left_id_column="id",
            right_catalog_dir=small_sky_source_catalog,
            right_ra_column="source_ra",
            right_dec_column="source_dec",
            right_id_column="source_id",
            input_path=small_sky_dir,
            input_format="csv",
            metadata_file_path=formats_yaml,
        )


def test_macauff_args_right_invalid_catalog(small_sky_object_catalog, small_sky_dir, formats_yaml, tmp_path):
    with pytest.raises(ValueError, match="right_catalog_dir"):
        MacauffArguments(
            output_path=tmp_path,
            output_catalog_name="object_to_source",
            tmp_dir=tmp_path,
            left_catalog_dir=small_sky_object_catalog,
            left_ra_column="ra",
            left_dec_column="dec",
            left_id_column="id",
            right_catalog_dir=small_sky_dir,  # valid directory with files, not a catalog
            right_ra_column="source_ra",
            right_dec_column="source_dec",
            right_id_column="source_id",
            input_path=small_sky_dir,
            input_format="csv",
            metadata_file_path=formats_yaml,
        )


def test_macauff_args_invalid_metadata_file(
    small_sky_object_catalog, small_sky_source_catalog, small_sky_dir, tmp_path
):
    with pytest.raises(ValueError, match="column metadata file must"):
        MacauffArguments(
            output_path=tmp_path,
            output_catalog_name="object_to_source",
            tmp_dir=tmp_path,
            left_catalog_dir=small_sky_object_catalog,
            left_ra_column="ra",
            left_dec_column="dec",
            left_id_column="id",
            right_catalog_dir=small_sky_source_catalog,
            right_ra_column="source_ra",
            right_dec_column="source_dec",
            right_id_column="source_id",
            input_path=small_sky_dir,
            input_format="csv",
            metadata_file_path="ceci_n_est_pas_un_fichier.xml",
        )


def test_macauff_args_invalid_input_directory(
    small_sky_object_catalog, small_sky_source_catalog, formats_yaml, tmp_path
):
    with pytest.raises(FileNotFoundError, match="input_path not found"):
        MacauffArguments(
            output_path=tmp_path,
            output_catalog_name="object_to_source",
            tmp_dir=tmp_path,
            left_catalog_dir=small_sky_object_catalog,
            left_ra_column="ra",
            left_dec_column="dec",
            left_id_column="id",
            right_catalog_dir=small_sky_source_catalog,
            right_ra_column="source_ra",
            right_dec_column="source_dec",
            right_id_column="source_id",
            input_path="ceci_n_est_pas_un_directoire/",
            input_format="csv",
            metadata_file_path=formats_yaml,
        )


def test_macauff_args_no_files(
    small_sky_object_catalog, small_sky_source_catalog, small_sky_dir, formats_yaml, tmp_path
):
    with pytest.raises(FileNotFoundError, match="No input files found"):
        MacauffArguments(
            output_path=tmp_path,
            output_catalog_name="object_to_source",
            tmp_dir=tmp_path,
            left_catalog_dir=small_sky_object_catalog,
            left_ra_column="ra",
            left_dec_column="dec",
            left_id_column="id",
            right_catalog_dir=small_sky_source_catalog,
            right_ra_column="source_ra",
            right_dec_column="source_dec",
            right_id_column="source_id",
            input_path=small_sky_dir,
            input_format="parquet",  # no files of this format will be found
            metadata_file_path=formats_yaml,
        )
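
The loop in test_empty_required above drives one MacauffArguments construction per required argument. An equivalent formulation with pytest parametrization (a hypothetical alternative, not part of this commit) would report each missing argument as its own test case; a sketch under the same fixtures and the same ValueError-matching expectation:

import pytest

from hipscat_import.cross_match.macauff_arguments import MacauffArguments


@pytest.mark.parametrize("missing", ["output_path", "left_catalog_dir", "input_format", "metadata_file_path"])
def test_missing_required_parametrized(
    missing, small_sky_object_catalog, small_sky_source_catalog, small_sky_dir, formats_yaml, tmp_path
):
    kwargs = dict(
        output_path=tmp_path,
        output_catalog_name="object_to_source",
        tmp_dir=tmp_path,
        left_catalog_dir=small_sky_object_catalog,
        left_ra_column="ra",
        left_dec_column="dec",
        left_id_column="id",
        right_catalog_dir=small_sky_source_catalog,
        right_ra_column="source_ra",
        right_dec_column="source_dec",
        right_id_column="source_id",
        input_path=small_sky_dir,
        input_format="csv",
        metadata_file_path=formats_yaml,
        overwrite=True,
    )
    kwargs[missing] = None  # blank out one required argument
    with pytest.raises(ValueError, match=missing):
        MacauffArguments(**kwargs)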