Commit d19f59e

But actually do the TODO this time.
delucchi-cmu committed Jan 19, 2024
1 parent ee3ac36 commit d19f59e
Showing 2 changed files with 19 additions and 25 deletions.
23 changes: 7 additions & 16 deletions src/hipscat_import/margin_cache/margin_cache_map_reduce.py
@@ -4,8 +4,9 @@
 from hipscat import pixel_math
 from hipscat.catalog.partition_info import PartitionInfo
 from hipscat.io import file_io, paths
+from hipscat.pixel_math.healpix_pixel import HealpixPixel

-# pylint: disable=too-many-locals,too-many-arguments
+from hipscat_import.pipeline_resume_plan import get_pixel_cache_directory


 def map_pixel_shards(
@@ -53,10 +54,8 @@ def _to_pixel_shard(data, margin_threshold, output_path, ra_column, dec_column):
     margin_data = data.loc[data["margin_check"] == True]

     if len(margin_data):
-        # TODO: this should be a utility function in `hipscat`
-        # that properly handles the hive formatting
-        # generate a file name for our margin shard
-        partition_dir = _get_partition_directory(output_path, order, pix)
+        # generate a file name for our margin shard, that uses both sets of Norder/Npix
+        partition_dir = get_pixel_cache_directory(output_path, HealpixPixel(order, pix))
         shard_dir = paths.pixel_directory(partition_dir, source_order, source_pix)

         file_io.make_directory(shard_dir, exist_ok=True)
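
The new comment's "both sets of Norder/Npix" points at a nesting the updated tests below make explicit: the partition pixel names the outer cache directory, while the source pixel names the hive-style shard file beneath it. A small sketch of the composed shard path, reusing only calls that appear in this diff (the output_path value is a hypothetical stand-in for the pipeline's margin cache root):

    from hipscat.io import paths
    from hipscat.pixel_math.healpix_pixel import HealpixPixel

    from hipscat_import.pipeline_resume_plan import get_pixel_cache_directory

    output_path = "/tmp/margin_cache"  # hypothetical cache root

    # partition pixel (1, 21) holding a shard of rows drawn from source pixel (1, 0)
    partition_dir = get_pixel_cache_directory(output_path, HealpixPixel(1, 21))
    shard_file = paths.pixel_catalog_file(partition_dir, 1, 0)
    # -> /tmp/margin_cache/order_1/dir_0/pixel_21/Norder=1/Dir=0/Npix=0.parquet
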
@@ -97,23 +96,15 @@ def _to_pixel_shard(data, margin_threshold, output_path, ra_column, dec_column):
     del data, margin_data, final_df


-def _get_partition_directory(path, order, pix):
-    """Get the directory where a partition pixel's margin shards live"""
-    partition_file = paths.pixel_catalog_file(path, order, pix)
-
-    # removes the '.parquet' and adds a slash
-    partition_dir = f"{partition_file[:-8]}/"
-
-    return partition_dir
-
-
 def reduce_margin_shards(output_path, partition_order, partition_pixel):
     """Reduce all partition pixel directories into a single file"""
-    shard_dir = _get_partition_directory(output_path, partition_order, partition_pixel)
+    shard_dir = get_pixel_cache_directory(output_path, HealpixPixel(partition_order, partition_pixel))

     if file_io.does_file_or_directory_exist(shard_dir):
         data = ds.dataset(shard_dir, format="parquet")
         full_df = data.to_table().to_pandas()
+        margin_cache_dir = paths.pixel_directory(output_path, partition_order, partition_pixel)
+        file_io.make_directory(margin_cache_dir, exist_ok=True)

         if len(full_df):
             margin_cache_file_path = paths.pixel_catalog_file(output_path, partition_order, partition_pixel)
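
get_pixel_cache_directory itself is imported from hipscat_import.pipeline_resume_plan and never shown in this diff, but the updated test paths below pin down the layout it must produce: order_{k}/dir_{d}/pixel_{p} under the cache root, distinct from the hive-style Norder=/Dir=/Npix= layout of the catalog proper. A minimal sketch consistent with those tests; the grouping of 10,000 pixels per dir_ level mirrors the hive Dir= convention and is an assumption here, not something the diff confirms:

    import os

    from hipscat.pixel_math.healpix_pixel import HealpixPixel


    def pixel_cache_directory_sketch(cache_path, pixel: HealpixPixel) -> str:
        """Sketch: where the intermediate shards for one partition pixel live."""
        directory_number = pixel.pixel // 10_000 * 10_000  # assumed grouping factor
        return os.path.join(
            str(cache_path),
            f"order_{pixel.order}",
            f"dir_{directory_number}",
            f"pixel_{pixel.pixel}",
        )


    # pixel_cache_directory_sketch("/tmp/margin_cache", HealpixPixel(1, 21))
    # -> "/tmp/margin_cache/order_1/dir_0/pixel_21"
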
21 changes: 12 additions & 9 deletions tests/hipscat_import/margin_cache/test_margin_cache_map_reduce.py
@@ -1,8 +1,12 @@
+import os
+
 import pandas as pd
 import pytest
-from hipscat.io import file_io, paths
+from hipscat.io import paths
+from hipscat.pixel_math.healpix_pixel import HealpixPixel

 from hipscat_import.margin_cache import margin_cache_map_reduce
+from hipscat_import.pipeline_resume_plan import get_pixel_cache_directory

 keep_cols = ["weird_ra", "weird_dec"]

@@ -39,9 +43,9 @@ def test_to_pixel_shard_equator(tmp_path, basic_data_shard_df):
         dec_column="weird_dec",
     )

-    path = file_io.append_paths_to_pointer(tmp_path, "Norder=1/Dir=0/Npix=21/Norder=1/Dir=0/Npix=0.parquet")
+    path = os.path.join(tmp_path, "order_1", "dir_0", "pixel_21", "Norder=1", "Dir=0", "Npix=0.parquet")

-    assert file_io.does_file_or_directory_exist(path)
+    assert os.path.exists(path)

     validate_result_dataframe(path, 2)

@@ -56,25 +60,23 @@ def test_to_pixel_shard_polar(tmp_path, polar_data_shard_df):
         dec_column="weird_dec",
     )

-    path = file_io.append_paths_to_pointer(tmp_path, "Norder=2/Dir=0/Npix=15/Norder=2/Dir=0/Npix=0.parquet")
+    path = os.path.join(tmp_path, "order_2", "dir_0", "pixel_15", "Norder=2", "Dir=0", "Npix=0.parquet")

-    assert file_io.does_file_or_directory_exist(path)
+    assert os.path.exists(path)

     validate_result_dataframe(path, 360)


 @pytest.mark.dask
 def test_reduce_margin_shards(tmp_path, basic_data_shard_df):
-    partition_dir = margin_cache_map_reduce._get_partition_directory(tmp_path, 1, 21)
+    partition_dir = get_pixel_cache_directory(tmp_path, HealpixPixel(1, 21))
     shard_dir = paths.pixel_directory(partition_dir, 1, 21)

-    file_io.make_directory(shard_dir, exist_ok=True)
+    os.makedirs(shard_dir)

     first_shard_path = paths.pixel_catalog_file(partition_dir, 1, 0)
     second_shard_path = paths.pixel_catalog_file(partition_dir, 1, 1)

-    print(first_shard_path)
-
     shard_df = basic_data_shard_df.drop(columns=["partition_order", "partition_pixel", "margin_pixel"])

     shard_df.to_parquet(first_shard_path)
@@ -85,3 +87,4 @@ def test_reduce_margin_shards(tmp_path, basic_data_shard_df):
     result_path = paths.pixel_catalog_file(tmp_path, 1, 21)

     validate_result_dataframe(result_path, 720)
+    assert not os.path.exists(shard_dir)
