Macauff import pipeline #186

Merged: 6 commits, Dec 21, 2023
31 changes: 1 addition & 30 deletions src/hipscat_import/catalog/resume_plan.py
@@ -32,8 +32,6 @@ class ResumePlan(PipelineResumePlan):
SPLITTING_STAGE = "splitting"
REDUCING_STAGE = "reducing"

ORIGINAL_INPUT_PATHS = "input_paths.txt"

HISTOGRAM_BINARY_FILE = "mapping_histogram.binary"
HISTOGRAMS_DIR = "histograms"

@@ -63,15 +61,7 @@ def gather_plan(self):
step_progress.update(1)

## Validate that we're operating on the same file set as the previous instance.
unique_file_paths = set(self.input_paths)
self.input_paths = list(unique_file_paths)
self.input_paths.sort()
original_input_paths = self.get_original_paths()
if not original_input_paths:
self.save_original_paths()
else:
if original_input_paths != unique_file_paths:
raise ValueError("Different file set from resumed pipeline execution.")
self.input_paths = self.check_original_input_paths(self.input_paths)
step_progress.update(1)

## Gather keys for execution.
@@ -97,25 +87,6 @@ def gather_plan(self):
)
step_progress.update(1)

def get_original_paths(self):
"""Get all input file paths from the first pipeline attempt."""
file_path = file_io.append_paths_to_pointer(self.tmp_path, self.ORIGINAL_INPUT_PATHS)
try:
with open(file_path, "r", encoding="utf-8") as file_handle:
contents = file_handle.readlines()
contents = [path.strip() for path in contents]
original_input_paths = set(contents)
return original_input_paths
except FileNotFoundError:
return []

def save_original_paths(self):
"""Save input file paths from the first pipeline attempt."""
file_path = file_io.append_paths_to_pointer(self.tmp_path, self.ORIGINAL_INPUT_PATHS)
with open(file_path, "w", encoding="utf-8") as file_handle:
for path in self.input_paths:
file_handle.write(f"{path}\n")

def get_remaining_map_keys(self):
"""Gather remaining keys, dropping successful mapping tasks from histogram names.

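The removed `get_original_paths` / `save_original_paths` helpers and the inline validation are collapsed into a single `check_original_input_paths` call, presumably promoted to the shared `PipelineResumePlan` base class so the new Macauff resume plan can reuse the same file-set check. Below is a minimal sketch of what that helper could look like, reconstructed from the logic deleted above; its actual location and signature are not shown in this diff.

```python
def check_original_input_paths(self, input_paths):
    """Deduplicate and sort the input paths, then verify they match the file
    set recorded by the first pipeline attempt (sketch of assumed behavior).

    `file_io` is `hipscat.io.file_io`, as used elsewhere in the resume plans;
    `ORIGINAL_INPUT_PATHS` is assumed to move to the base class with this helper.
    """
    unique_file_paths = set(input_paths)
    input_paths = sorted(unique_file_paths)

    file_path = file_io.append_paths_to_pointer(self.tmp_path, self.ORIGINAL_INPUT_PATHS)
    try:
        with open(file_path, "r", encoding="utf-8") as file_handle:
            original_input_paths = {path.strip() for path in file_handle.readlines()}
    except FileNotFoundError:
        original_input_paths = set()

    if not original_input_paths:
        # First attempt: record the file set for later resume validation.
        with open(file_path, "w", encoding="utf-8") as file_handle:
            for path in input_paths:
                file_handle.write(f"{path}\n")
    elif original_input_paths != unique_file_paths:
        raise ValueError("Different file set from resumed pipeline execution.")
    return input_paths
```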
72 changes: 38 additions & 34 deletions src/hipscat_import/cross_match/macauff_arguments.py
@@ -4,9 +4,12 @@
from os import path
from typing import List

from hipscat.catalog.association_catalog.association_catalog import AssociationCatalogInfo
from hipscat.catalog.catalog_type import CatalogType
from hipscat.io import FilePointer
from hipscat.io.validation import is_valid_catalog

from hipscat_import.catalog.file_readers import InputReader, get_file_reader
from hipscat_import.runtime_arguments import RuntimeArguments, find_input_paths

# pylint: disable=too-many-instance-attributes
@@ -28,8 +31,6 @@ class MacauffArguments(RuntimeArguments):
"""can be used instead of `input_format` to import only specified files"""
input_paths: List[FilePointer] = field(default_factory=list)
"""resolved list of all files that will be used in the importer"""
add_hipscat_index: bool = True
"""add the hipscat spatial index field alongside the data"""

## Input - Left catalog
left_catalog_dir: str = ""
@@ -45,8 +46,11 @@

## `macauff` specific attributes
metadata_file_path: str = ""
match_probability_columns: List[str] = field(default_factory=list)
column_names: List[str] = field(default_factory=list)
resume: bool = True
"""if there are existing intermediate resume files, should we
read those and continue to create a new catalog where we left off"""

file_reader: InputReader | None = None

def __post_init__(self):
self._check_arguments()
@@ -89,33 +93,33 @@ def _check_arguments(self):
# Basic checks complete - make more checks and create directories where necessary
self.input_paths = find_input_paths(self.input_path, f"*{self.input_format}", self.input_file_list)

self.column_names = self.get_column_names()

def get_column_names(self):
"""Grab the macauff column names."""
# TODO: Actually read in the metadata file once we get the example file from Tom.

return [
"Gaia_designation",
"Gaia_RA",
"Gaia_Dec",
"BP",
"G",
"RP",
"CatWISE_Name",
"CatWISE_RA",
"CatWISE_Dec",
"W1",
"W2",
"match_p",
"Separation",
"eta",
"xi",
"Gaia_avg_cont",
"CatWISE_avg_cont",
"Gaia_cont_f1",
"Gaia_cont_f10",
"CatWISE_cont_f1",
"CatWISE_cont_f10",
"CatWISE_fit_sig",
]
if not self.file_reader:
self.file_reader = get_file_reader(file_format=self.input_format)

def to_catalog_info(self, total_rows) -> AssociationCatalogInfo:
"""Catalog-type-specific dataset info."""
info = {
"catalog_name": self.output_artifact_name,
"catalog_type": CatalogType.ASSOCIATION,
"total_rows": total_rows,
"primary_column": self.left_id_column,
"primary_catalog": str(self.left_catalog_dir),
"join_column": self.right_id_column,
"join_catalog": str(self.right_catalog_dir),
}
return AssociationCatalogInfo(**info)

def additional_runtime_provenance_info(self) -> dict:
return {
"input_path": self.input_path,
"input_format": self.input_format,
"left_catalog_dir": self.left_catalog_dir,
"left_id_column": self.left_id_column,
"left_ra_column": self.left_ra_column,
"left_dec_column": self.left_dec_column,
"right_catalog_dir": self.right_catalog_dir,
"right_id_column": self.right_id_column,
"right_ra_column": self.right_ra_column,
"right_dec_column": self.right_dec_column,
"metadata_file_path": self.metadata_file_path,
}
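For context, a usage sketch of the reworked arguments class. Every path and the metadata file below are illustrative placeholders; the column names come from the Gaia/CatWISE example this pipeline targets, and the inherited keywords (`output_artifact_name`, `output_path`, `tmp_dir`) are assumed to come from the shared `RuntimeArguments` base.

```python
from hipscat_import.cross_match.macauff_arguments import MacauffArguments

# All paths below are hypothetical placeholders.
args = MacauffArguments(
    output_artifact_name="gaia_catwise_macauff",
    output_path="/tmp/macauff_output",
    tmp_dir="/tmp/macauff_intermediate",
    left_catalog_dir="/data/gaia_hipscat",
    left_id_column="Gaia_designation",
    left_ra_column="Gaia_RA",
    left_dec_column="Gaia_Dec",
    right_catalog_dir="/data/catwise_hipscat",
    right_id_column="CatWISE_Name",
    right_ra_column="CatWISE_RA",
    right_dec_column="CatWISE_Dec",
    input_path="/data/macauff_results",
    input_format="csv",
    metadata_file_path="/data/macauff_results/columns.yaml",
)

# The association-catalog metadata is built from the left/right settings above.
catalog_info = args.to_catalog_info(total_rows=0)
```

Note that when `file_reader` is left unset, `_check_arguments` now falls back to `get_file_reader(file_format=self.input_format)`, so a reader only needs to be passed for non-default parsing options.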
106 changes: 106 additions & 0 deletions src/hipscat_import/cross_match/macauff_map_reduce.py
@@ -0,0 +1,106 @@
import healpy as hp
import numpy as np
import pyarrow as pa
import pyarrow.parquet as pq
from hipscat.io import file_io, paths
from hipscat.pixel_math.healpix_pixel import HealpixPixel
from hipscat.pixel_math.healpix_pixel_function import get_pixel_argsort

from hipscat_import.catalog.map_reduce import _get_pixel_directory, _iterate_input_file
from hipscat_import.cross_match.macauff_resume_plan import MacauffResumePlan

# pylint: disable=too-many-arguments,too-many-locals


def split_associations(
input_file,
file_reader,
splitting_key,
highest_left_order,
highest_right_order,
left_alignment,
right_alignment,
left_ra_column,
left_dec_column,
right_ra_column,
right_dec_column,
tmp_path,
):
"""Map a file of links to their healpix pixels and split into shards.


Raises:
ValueError: if the `ra_column` or `dec_column` cannot be found in the input file.
FileNotFoundError: if the file does not exist, or is a directory
"""
for chunk_number, data, mapped_left_pixels in _iterate_input_file(
input_file, file_reader, highest_left_order, left_ra_column, left_dec_column, False
):
aligned_left_pixels = left_alignment[mapped_left_pixels]
unique_pixels, unique_inverse = np.unique(aligned_left_pixels, return_inverse=True)

mapped_right_pixels = hp.ang2pix(
2**highest_right_order,
data[right_ra_column].values,
data[right_dec_column].values,
lonlat=True,
nest=True,
)
aligned_right_pixels = right_alignment[mapped_right_pixels]

data["Norder"] = [pix.order for pix in aligned_left_pixels]
data["Dir"] = [pix.dir for pix in aligned_left_pixels]
data["Npix"] = [pix.pixel for pix in aligned_left_pixels]

data["join_Norder"] = [pix.order for pix in aligned_right_pixels]
data["join_Dir"] = [pix.dir for pix in aligned_right_pixels]
data["join_Npix"] = [pix.pixel for pix in aligned_right_pixels]

for unique_index, pixel in enumerate(unique_pixels):
mapped_indexes = np.where(unique_inverse == unique_index)
data_indexes = data.index[mapped_indexes[0].tolist()]

filtered_data = data.filter(items=data_indexes, axis=0)

pixel_dir = _get_pixel_directory(tmp_path, pixel.order, pixel.pixel)
file_io.make_directory(pixel_dir, exist_ok=True)
output_file = file_io.append_paths_to_pointer(
pixel_dir, f"shard_{splitting_key}_{chunk_number}.parquet"
)
filtered_data.to_parquet(output_file, index=False)
del data, filtered_data, data_indexes

MacauffResumePlan.splitting_key_done(tmp_path=tmp_path, splitting_key=splitting_key)


def reduce_associations(left_pixel, tmp_path, catalog_path, reduce_key):
"""For all points determined to be in the target left_pixel, map them to the appropriate right_pixel
and aggregate into a single parquet file."""
inputs = _get_pixel_directory(tmp_path, left_pixel.order, left_pixel.pixel)

if not file_io.directory_has_contents(inputs):
MacauffResumePlan.reducing_key_done(
tmp_path=tmp_path, reducing_key=f"{left_pixel.order}_{left_pixel.pixel}"
)
print(f"Warning: no input data for pixel {left_pixel}")
return
destination_dir = paths.pixel_directory(catalog_path, left_pixel.order, left_pixel.pixel)
file_io.make_directory(destination_dir, exist_ok=True)

destination_file = paths.pixel_catalog_file(catalog_path, left_pixel.order, left_pixel.pixel)

merged_table = pq.read_table(inputs)
dataframe = merged_table.to_pandas().reset_index()

## One row group per join_Norder/join_Npix

join_pixel_frames = dataframe.groupby(["join_Norder", "join_Npix"], group_keys=True)
join_pixels = [HealpixPixel(pixel[0], pixel[1]) for pixel, _ in join_pixel_frames]
pixel_argsort = get_pixel_argsort(join_pixels)
with pq.ParquetWriter(destination_file, merged_table.schema) as writer:
for pixel_index in pixel_argsort:
join_pixel = join_pixels[pixel_index]
join_pixel_frame = join_pixel_frames.get_group((join_pixel.order, join_pixel.pixel)).reset_index()
writer.write_table(pa.Table.from_pandas(join_pixel_frame, schema=merged_table.schema))

MacauffResumePlan.reducing_key_done(tmp_path=tmp_path, reducing_key=reduce_key)
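A sketch of how these two stages might be driven by hand for one input file and one left pixel. Every path, order, and alignment value below is a placeholder; the real pipeline presumably builds the alignment arrays from the left and right catalog pixel trees and fans these calls out across workers under the control of `MacauffResumePlan`.

```python
import healpy as hp
import numpy as np
from hipscat.pixel_math.healpix_pixel import HealpixPixel

from hipscat_import.catalog.file_readers import get_file_reader
from hipscat_import.cross_match.macauff_map_reduce import reduce_associations, split_associations

highest_order = 7
target_pixel = HealpixPixel(2, 187)  # hypothetical left pixel

# Placeholder alignments: map every order-7 healpix index to the same coarse pixel.
left_alignment = np.full(hp.nside2npix(2**highest_order), target_pixel)
right_alignment = np.full(hp.nside2npix(2**highest_order), target_pixel)

split_associations(
    input_file="/data/macauff_results/matches_000.csv",  # hypothetical
    file_reader=get_file_reader(file_format="csv"),
    splitting_key="split_0",
    highest_left_order=highest_order,
    highest_right_order=highest_order,
    left_alignment=left_alignment,
    right_alignment=right_alignment,
    left_ra_column="Gaia_RA",
    left_dec_column="Gaia_Dec",
    right_ra_column="CatWISE_RA",
    right_dec_column="CatWISE_Dec",
    tmp_path="/tmp/macauff_intermediate/splitting",  # hypothetical
)

reduce_associations(
    left_pixel=target_pixel,
    tmp_path="/tmp/macauff_intermediate/splitting",
    catalog_path="/tmp/macauff_output/gaia_catwise_macauff",
    reduce_key="2_187",
)
```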
2 changes: 1 addition & 1 deletion src/hipscat_import/cross_match/macauff_metadata.py
@@ -94,7 +94,7 @@ def from_yaml(input_file, output_directory):
table_name = table.get("name", f"metadata_table_{index}")
for col_index, column in enumerate(table.get("columns", [])):
name = column.get("name", f"column_{col_index}")
units = column.get("units", "string")
units = column.get("datatype", "string")
fields.append(_construct_field(name, units, metadata_dict=column))

schema = pa.schema(fields)
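The one-line change above makes the schema builder read each column's `datatype` entry instead of its `units` entry; the local variable keeps its old name, but the value passed to `_construct_field` is now the type string rather than a physical unit. A simplified, self-contained sketch of the distinction, with an assumed column entry and a hypothetical stand-in for `_construct_field` (the real helper's mapping is not shown in this diff):

```python
import pyarrow as pa

# Assumed shape of one column entry in the macauff metadata YAML.
column = {"name": "Gaia_RA", "datatype": "double", "units": "degrees"}

# Hypothetical type mapping for the sketch.
_TYPE_MAP = {"double": pa.float64(), "float": pa.float32(), "long": pa.int64(), "string": pa.string()}


def construct_field_sketch(name, datatype, metadata_dict):
    """Stand-in for _construct_field: the Arrow type comes from `datatype`;
    `units` survives only as field metadata."""
    return pa.field(
        name,
        _TYPE_MAP.get(datatype, pa.string()),
        metadata={"units": metadata_dict.get("units", "")},
    )


field = construct_field_sketch(
    column.get("name"), column.get("datatype", "string"), metadata_dict=column
)
assert field.type == pa.float64()
```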