Margin generation improvements. (#356)
delucchi-cmu authored Jul 31, 2024
1 parent ee02e32 commit 33b89b2
Showing 10 changed files with 878 additions and 79 deletions.
8 changes: 2 additions & 6 deletions src/hipscat_import/margin_cache/margin_cache.py
@@ -35,6 +35,7 @@ def generate_margin_cache(args, client):
margin_order=args.margin_order,
ra_column=args.catalog.catalog_info.ra_column,
dec_column=args.catalog.catalog_info.dec_column,
fine_filtering=args.fine_filtering,
)
)
resume_plan.wait_for_mapping(futures)
@@ -59,16 +60,11 @@ def generate_margin_cache(args, client):
resume_plan.wait_for_reducing(futures)

with resume_plan.print_progress(total=4, stage_name="Finishing") as step_progress:
parquet_metadata.write_parquet_metadata(
total_rows = parquet_metadata.write_parquet_metadata(
args.catalog_path, storage_options=args.output_storage_options
)
step_progress.update(1)
total_rows = 0
metadata_path = paths.get_parquet_metadata_pointer(args.catalog_path)
for row_group in parquet_metadata.read_row_group_fragments(
metadata_path, storage_options=args.output_storage_options
):
total_rows += row_group.num_rows
partition_info = PartitionInfo.read_from_file(
metadata_path, storage_options=args.output_storage_options
)
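With this change, `write_parquet_metadata` reports the total row count directly, replacing the loop over `_metadata` row groups that the old code used. As a rough sketch of what that removed loop computed — written here with plain pyarrow rather than the hipscat helpers, and with an illustrative path — counting rows from a consolidated `_metadata` file looks like this:

```python
import pyarrow.parquet as pq

# Illustrative path to a catalog's consolidated parquet metadata file.
metadata_path = "/path/to/margin_catalog/_metadata"

# Sum the row counts recorded for every row group in the _metadata file;
# this mirrors what the removed loop computed through the hipscat helpers.
metadata = pq.read_metadata(metadata_path)
total_rows = sum(
    metadata.row_group(i).num_rows for i in range(metadata.num_row_groups)
)
print(total_rows)
```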
3 changes: 3 additions & 0 deletions src/hipscat_import/margin_cache/margin_cache_arguments.py
@@ -25,6 +25,9 @@ class MarginCacheArguments(RuntimeArguments):
order of healpix partitioning in the source catalog. if `margin_order` is left
default or set to -1, then the `margin_order` will be set dynamically to the
highest partition order plus 1."""
fine_filtering: bool = True
"""should we perform the precise boundary checking? if false, some results may be
greater than `margin_threshold` away from the border (but within `margin_order`)."""
delete_intermediate_parquet_files: bool = True
"""should we delete the smaller intermediate parquet files generated in the
splitting stage, once the relevant reducing stage is complete?"""
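The new `fine_filtering` flag defaults to `True`, so existing pipelines keep the precise per-row boundary check; setting it to `False` trades precision for speed, as the docstring above notes. A minimal sketch of opting out — the paths and every argument name other than `fine_filtering` are illustrative assumptions, not taken from this diff:

```python
from dask.distributed import Client

from hipscat_import.margin_cache.margin_cache import generate_margin_cache
from hipscat_import.margin_cache.margin_cache_arguments import MarginCacheArguments

# NOTE: the paths and the argument names other than `fine_filtering` are
# illustrative assumptions here.
args = MarginCacheArguments(
    input_catalog_path="/path/to/object_catalog",
    output_path="/path/to/output",
    output_artifact_name="object_catalog_10arcs",
    margin_threshold=10.0,   # arcseconds
    fine_filtering=False,    # skip the precise per-row boundary check
)

with Client(n_workers=4) as client:
    generate_margin_cache(args, client)
```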
102 changes: 63 additions & 39 deletions src/hipscat_import/margin_cache/margin_cache_map_reduce.py
@@ -7,12 +7,12 @@
from hipscat.catalog.partition_info import PartitionInfo
from hipscat.io import file_io, paths
from hipscat.pixel_math.healpix_pixel import HealpixPixel
from hipscat.pixel_math.hipscat_id import HIPSCAT_ID_COLUMN

from hipscat_import.margin_cache.margin_cache_resume_plan import MarginCachePlan
from hipscat_import.pipeline_resume_plan import get_pixel_cache_directory, print_task_failure


# pylint: disable=too-many-arguments
def map_pixel_shards(
partition_file,
mapping_key,
@@ -24,6 +24,7 @@ def map_pixel_shards(
margin_order,
ra_column,
dec_column,
fine_filtering,
):
"""Creates margin cache shards from a source partition file."""
try:
@@ -33,25 +34,47 @@ def map_pixel_shards(
data = file_io.read_parquet_file_to_pandas(
partition_file, schema=schema, storage_options=input_storage_options
)
source_pixel = HealpixPixel(data["Norder"].iloc[0], data["Npix"].iloc[0])

data["margin_pixel"] = hp.ang2pix(
# Constrain the possible margin pairs, first by only those `margin_order` pixels
# that **can** be contained in source pixel, then by `margin_order` pixels for rows
# in source data
margin_pairs = pd.read_csv(margin_pair_file)
explosion_factor = 4 ** (margin_order - source_pixel.order)
margin_pixel_range_start = source_pixel.pixel * explosion_factor
margin_pixel_range_end = (source_pixel.pixel + 1) * explosion_factor
margin_pairs = margin_pairs.query(
f"margin_pixel >= {margin_pixel_range_start} and margin_pixel < {margin_pixel_range_end}"
)

margin_pixel_list = hp.ang2pix(
2**margin_order,
data[ra_column].values,
data[dec_column].values,
lonlat=True,
nest=True,
)

margin_pairs = pd.read_csv(margin_pair_file)
constrained_data = data.reset_index().merge(margin_pairs, on="margin_pixel")

if len(constrained_data):
constrained_data.groupby(["partition_order", "partition_pixel"]).apply(
_to_pixel_shard,
margin_pixel_filter = pd.DataFrame(
{"margin_pixel": margin_pixel_list, "filter_value": np.arange(0, len(margin_pixel_list))}
).merge(margin_pairs, on="margin_pixel")

# For every possible output pixel, find the full margin_order pixel filter list,
# perform the filter, and pass along to helper method to compute fine filter
# and write out shard file.
for partition_key, data_filter in margin_pixel_filter.groupby(["partition_order", "partition_pixel"]):
data_filter = np.unique(data_filter["filter_value"]).tolist()
pixel = HealpixPixel(partition_key[0], partition_key[1])

filtered_data = data.iloc[data_filter]
_to_pixel_shard(
filtered_data=filtered_data,
pixel=pixel,
margin_threshold=margin_threshold,
output_path=output_path,
ra_column=ra_column,
dec_column=dec_column,
source_pixel=source_pixel,
fine_filtering=fine_filtering,
)

MarginCachePlan.mapping_key_done(output_path, mapping_key)
@@ -60,60 +83,61 @@
raise exception


def _to_pixel_shard(data, margin_threshold, output_path, ra_column, dec_column):
def _to_pixel_shard(
filtered_data,
pixel,
margin_threshold,
output_path,
ra_column,
dec_column,
source_pixel,
fine_filtering,
):
"""Do boundary checking for the cached partition and then output remaining data."""
order, pix = data["partition_order"].iloc[0], data["partition_pixel"].iloc[0]
source_order, source_pix = data["Norder"].iloc[0], data["Npix"].iloc[0]

data["margin_check"] = pixel_math.check_margin_bounds(
data[ra_column].values, data[dec_column].values, order, pix, margin_threshold
)
if fine_filtering:
margin_check = pixel_math.check_margin_bounds(
filtered_data[ra_column].values,
filtered_data[dec_column].values,
pixel.order,
pixel.pixel,
margin_threshold,
)

# pylint: disable-next=singleton-comparison
margin_data = data.loc[data["margin_check"] == True]
margin_data = filtered_data.iloc[margin_check]
else:
margin_data = filtered_data

if len(margin_data):
# generate a file name for our margin shard, that uses both sets of Norder/Npix
partition_dir = get_pixel_cache_directory(output_path, HealpixPixel(order, pix))
shard_dir = paths.pixel_directory(partition_dir, source_order, source_pix)
partition_dir = get_pixel_cache_directory(output_path, pixel)
shard_dir = paths.pixel_directory(partition_dir, source_pixel.order, source_pixel.pixel)

file_io.make_directory(shard_dir, exist_ok=True)

shard_path = paths.pixel_catalog_file(partition_dir, source_order, source_pix)

final_df = margin_data.drop(
columns=[
"margin_check",
"margin_pixel",
]
)
shard_path = paths.pixel_catalog_file(partition_dir, source_pixel.order, source_pixel.pixel)

rename_columns = {
PartitionInfo.METADATA_ORDER_COLUMN_NAME: f"margin_{PartitionInfo.METADATA_ORDER_COLUMN_NAME}",
PartitionInfo.METADATA_DIR_COLUMN_NAME: f"margin_{PartitionInfo.METADATA_DIR_COLUMN_NAME}",
PartitionInfo.METADATA_PIXEL_COLUMN_NAME: f"margin_{PartitionInfo.METADATA_PIXEL_COLUMN_NAME}",
"partition_order": PartitionInfo.METADATA_ORDER_COLUMN_NAME,
"partition_pixel": PartitionInfo.METADATA_PIXEL_COLUMN_NAME,
}

final_df.rename(columns=rename_columns, inplace=True)
margin_data = margin_data.rename(columns=rename_columns)

dir_column = np.floor_divide(final_df[PartitionInfo.METADATA_PIXEL_COLUMN_NAME].values, 10000) * 10000
margin_data[PartitionInfo.METADATA_ORDER_COLUMN_NAME] = pixel.order
margin_data[PartitionInfo.METADATA_DIR_COLUMN_NAME] = pixel.dir
margin_data[PartitionInfo.METADATA_PIXEL_COLUMN_NAME] = pixel.pixel

final_df[PartitionInfo.METADATA_DIR_COLUMN_NAME] = dir_column

final_df = final_df.astype(
margin_data = margin_data.astype(
{
PartitionInfo.METADATA_ORDER_COLUMN_NAME: np.uint8,
PartitionInfo.METADATA_DIR_COLUMN_NAME: np.uint64,
PartitionInfo.METADATA_PIXEL_COLUMN_NAME: np.uint64,
}
)
final_df = final_df.set_index(HIPSCAT_ID_COLUMN).sort_index()

final_df.to_parquet(shard_path)
margin_data = margin_data.sort_index()

del data, margin_data, final_df
margin_data.to_parquet(shard_path)


def reduce_margin_shards(
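The pre-filter added to `map_pixel_shards` leans on a property of NESTED HEALPix ordering: a pixel at order k covers exactly 4^(m−k) consecutive pixels at a finer order m, so the candidate `margin_order` pixels for one source partition form a single contiguous range. A worked example with toy numbers (source partition at order 1, pixel 21; margin order 3):

```python
# Toy numbers: source partition at order 1, pixel 21; margin pixels at order 3.
source_order, source_pix = 1, 21
margin_order = 3

# In NESTED ordering, a pixel at order k covers a contiguous block of 4**(m - k)
# pixels at a finer order m, so the candidate margin pixels form one simple range.
explosion_factor = 4 ** (margin_order - source_order)   # 16
range_start = source_pix * explosion_factor             # 336
range_end = (source_pix + 1) * explosion_factor         # 352

# Rows whose order-3 margin pixel falls outside [336, 352) cannot lie inside this
# source partition, so the margin-pair table can be pre-filtered to that range
# before the (optional) fine filtering runs.
print(range_start, range_end)
```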
4 changes: 2 additions & 2 deletions src/hipscat_import/margin_cache/margin_cache_resume_plan.py
@@ -55,7 +55,7 @@ def _gather_plan(self, args):
self.margin_pair_file = file_io.append_paths_to_pointer(self.tmp_path, self.MARGIN_PAIR_FILE)
if not file_io.does_file_or_directory_exist(self.margin_pair_file):
margin_pairs = _find_partition_margin_pixel_pairs(self.combined_pixels, args.margin_order)
margin_pairs.to_csv(self.margin_pair_file)
margin_pairs.to_csv(self.margin_pair_file, index=False)
step_progress.update(1)

file_io.make_directory(
@@ -167,5 +167,5 @@ def _find_partition_margin_pixel_pairs(combined_pixels, margin_order):
margin_pairs_df = pd.DataFrame(
zip(norders, part_pix, margin_pix),
columns=["partition_order", "partition_pixel", "margin_pixel"],
)
).sort_values("margin_pixel")
return margin_pairs_df
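Writing the margin-pair file sorted by `margin_pixel` and without the pandas index keeps the CSV clean and lets the range query in `map_pixel_shards` pick up a contiguous block of rows. A minimal sketch of that write-and-filter round trip, with made-up pixel values:

```python
import pandas as pd

# Made-up pair rows, purely to show the file layout; real rows come from
# _find_partition_margin_pixel_pairs and relate each partition to its bordering
# margin-order pixels.
margin_pairs = pd.DataFrame(
    {
        "partition_order": [1, 1, 1, 1],
        "partition_pixel": [20, 20, 23, 23],
        "margin_pixel": [320, 336, 351, 360],
    }
).sort_values("margin_pixel")

# As in this commit: sorted by margin_pixel, no pandas index column in the CSV.
margin_pairs.to_csv("margin_pairs.csv", index=False)

# A mapping task later reloads the file and keeps only the margin pixels that can
# fall inside one source partition ([336, 352) matches the worked example above).
pairs = pd.read_csv("margin_pairs.csv")
print(pairs.query("margin_pixel >= 336 and margin_pixel < 352"))
```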
30 changes: 2 additions & 28 deletions tests/hipscat_import/conftest.py
@@ -4,7 +4,6 @@
import re
from pathlib import Path

import hipscat.pixel_math.healpix_shim as hp
import numpy as np
import numpy.testing as npt
import pandas as pd
@@ -170,67 +169,42 @@ def resume_dir(test_data_dir):
def basic_data_shard_df():
ras = np.arange(0.0, 360.0)
dec = np.full(360, 0.0)
ppix = np.full(360, 21)
porder = np.full(360, 1)
norder = np.full(360, 1)
npix = np.full(360, 0)
hipscat_indexes = pixel_math.compute_hipscat_id(ras, dec)

test_df = pd.DataFrame(
data=zip(hipscat_indexes, ras, dec, ppix, porder, norder, npix),
data=zip(hipscat_indexes, ras, dec, norder, npix),
columns=[
"_hipscat_index",
"weird_ra",
"weird_dec",
"partition_pixel",
"partition_order",
"Norder",
"Npix",
],
)

test_df["margin_pixel"] = hp.ang2pix(
2**3,
test_df["weird_ra"].values,
test_df["weird_dec"].values,
lonlat=True,
nest=True,
)

return test_df


@pytest.fixture
def polar_data_shard_df():
ras = np.arange(0.0, 360.0)
dec = np.full(360, 89.9)
ppix = np.full(360, 15)
porder = np.full(360, 2)
norder = np.full(360, 2)
npix = np.full(360, 0)
hipscat_indexes = pixel_math.compute_hipscat_id(ras, dec)

test_df = pd.DataFrame(
data=zip(hipscat_indexes, ras, dec, ppix, porder, norder, npix),
data=zip(hipscat_indexes, ras, dec, norder, npix),
columns=[
"_hipscat_index",
"weird_ra",
"weird_dec",
"partition_pixel",
"partition_order",
"Norder",
"Npix",
],
)

test_df["margin_pixel"] = hp.ang2pix(
2**3,
test_df["weird_ra"].values,
test_df["weird_dec"].values,
lonlat=True,
nest=True,
)

return test_df


(Diffs for the remaining 5 changed files are not shown here.)
