Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Keep track of expected number of margin rows written. #418

Merged
merged 4 commits into from
Oct 23, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion src/hats_import/margin_cache/margin_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@
)
resume_plan.wait_for_mapping(futures)

with resume_plan.print_progress(total=1, stage_name="Binning") as step_progress:
total_rows = resume_plan.get_mapping_total()
if not total_rows:
raise ValueError("Margin cache contains no rows. Increase margin size and re-run.")
delucchi-cmu marked this conversation as resolved.
Show resolved Hide resolved

if not resume_plan.is_reducing_done():
futures = []
for reducing_key, pix in resume_plan.get_remaining_reduce_keys():
Expand All @@ -57,7 +62,11 @@
resume_plan.wait_for_reducing(futures)

with resume_plan.print_progress(total=4, stage_name="Finishing") as step_progress:
total_rows = parquet_metadata.write_parquet_metadata(args.catalog_path)
metadata_total_rows = parquet_metadata.write_parquet_metadata(args.catalog_path)
if metadata_total_rows != total_rows:
raise ValueError(

Check warning on line 67 in src/hats_import/margin_cache/margin_cache.py

View check run for this annotation

Codecov / codecov/patch

src/hats_import/margin_cache/margin_cache.py#L67

Added line #L67 was not covered by tests
f"Wrote unexpected number of rows ({total_rows} expected, {metadata_total_rows} written)"
)
step_progress.update(1)
metadata_path = paths.get_parquet_metadata_pointer(args.catalog_path)
partition_info = PartitionInfo.read_from_file(metadata_path)
Expand Down
9 changes: 6 additions & 3 deletions src/hats_import/margin_cache/margin_cache_map_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,13 @@ def map_pixel_shards(
# For every possible output pixel, find the full margin_order pixel filter list,
# perform the filter, and pass along to helper method to compute fine filter
# and write out shard file.
num_rows = 0
for partition_key, data_filter in margin_pixel_filter.groupby(["partition_order", "partition_pixel"]):
data_filter = np.unique(data_filter["filter_value"]).tolist()
pixel = HealpixPixel(partition_key[0], partition_key[1])

filtered_data = data.iloc[data_filter]
_to_pixel_shard(
num_rows += _to_pixel_shard(
filtered_data=filtered_data,
pixel=pixel,
margin_threshold=margin_threshold,
Expand All @@ -71,7 +72,7 @@ def map_pixel_shards(
fine_filtering=fine_filtering,
)

MarginCachePlan.mapping_key_done(output_path, mapping_key)
MarginCachePlan.mapping_key_done(output_path, mapping_key, num_rows)
except Exception as exception: # pylint: disable=broad-exception-caught
print_task_failure(f"Failed MAPPING stage for pixel: {mapping_key}", exception)
raise exception
Expand Down Expand Up @@ -101,7 +102,8 @@ def _to_pixel_shard(
else:
margin_data = filtered_data

if len(margin_data):
num_rows = len(margin_data)
if num_rows:
# generate a file name for our margin shard, that uses both sets of Norder/Npix
partition_dir = get_pixel_cache_directory(output_path, pixel)
shard_dir = paths.pixel_directory(partition_dir, source_pixel.order, source_pixel.pixel)
Expand Down Expand Up @@ -132,6 +134,7 @@ def _to_pixel_shard(
margin_data = margin_data.sort_index()

margin_data.to_parquet(shard_path.path, filesystem=shard_path.fs)
return num_rows


def reduce_margin_shards(
Expand Down
28 changes: 26 additions & 2 deletions src/hats_import/margin_cache/margin_cache_resume_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
MAPPING_STAGE = "mapping"
REDUCING_STAGE = "reducing"
MARGIN_PAIR_FILE = "margin_pair.csv"
MAPPING_TOTAL_FILE = "mapping_total"

def __init__(self, args: MarginCacheArguments):
if not args.tmp_path: # pragma: no cover (not reachable, but required for mypy)
Expand Down Expand Up @@ -64,14 +65,14 @@
step_progress.update(1)

@classmethod
def mapping_key_done(cls, tmp_path, mapping_key: str, num_rows: int):
    """Mark a single mapping task as done, recording the number of rows it wrote.

    Args:
        tmp_path (str): where to write intermediate resume files.
        mapping_key (str): unique string for each mapping task (e.g. "map_1_24")
        num_rows (int): number of margin rows written by this mapping task;
            persisted inside the marker file so the totals can be summed later.
    """
    cls.write_marker_file(tmp_path, cls.MAPPING_STAGE, mapping_key, str(num_rows))

def wait_for_mapping(self, futures):
"""Wait for mapping stage futures to complete."""
Expand All @@ -87,6 +88,22 @@
"""Are there sources left to count?"""
return self.done_file_exists(self.MAPPING_STAGE)

def get_mapping_total(self) -> int:
    """Total number of rows written during mapping, summed from intermediate markers.

    The first call computes the sum from the per-task marker files and caches it
    in a dedicated total file; subsequent calls read back the cached value.

    Raises:
        ValueError: if the mapping stage has not completed yet.
    """
    if not self.is_mapping_done():
        raise ValueError("mapping stage is not done yet.")

    cache_file = file_io.append_paths_to_pointer(self.tmp_path, self.MAPPING_TOTAL_FILE)

    # Re-use a previously-computed total if one was already persisted.
    if file_io.does_file_or_directory_exist(cache_file):
        return _marker_value_to_int(file_io.load_text_file(cache_file))

    per_task_markers = self.read_markers(self.MAPPING_STAGE)
    total = sum(_marker_value_to_int(contents) for contents in per_task_markers.values())
    file_io.write_string_to_file(cache_file, str(total))
    return total

def get_remaining_map_keys(self):
"""Fetch a tuple for each pixel/partition left to map."""
map_keys = set(self.read_done_keys(self.MAPPING_STAGE))
Expand Down Expand Up @@ -163,3 +180,10 @@
columns=["partition_order", "partition_pixel", "margin_pixel"],
).sort_values("margin_pixel")
return margin_pairs_df


def _marker_value_to_int(marker_value: List[str]) -> int:
delucchi-cmu marked this conversation as resolved.
Show resolved Hide resolved
"""Convenience method to parse the contents of a marker file."""
if len(marker_value) != 1:
raise ValueError("Marker file should contain only one integer value.")

Check warning on line 188 in src/hats_import/margin_cache/margin_cache_resume_plan.py

View check run for this annotation

Codecov / codecov/patch

src/hats_import/margin_cache/margin_cache_resume_plan.py#L188

Added line #L188 was not covered by tests
return int(marker_value[0])
36 changes: 35 additions & 1 deletion src/hats_import/pipeline_resume_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,46 @@ def touch_key_done_file(cls, tmp_path, stage_name, key):
"""Touch (create) a done file for a single key, within a pipeline stage.

Args:
stage_name(str): name of the stage (e.g. mapping, reducing)
tmp_path: where to write pipeline intermediate files
stage_name(str): name of the stage (e.g. mapping, reducing)
key (str): unique string for each task (e.g. "map_57")
"""
Path(file_io.append_paths_to_pointer(tmp_path, stage_name, f"{key}_done")).touch()

@classmethod
def write_marker_file(cls, tmp_path, stage_name, key, value):
    """Write a marker file for a single key within a pipeline stage.

    Like a "done" file, but carries a payload inside the file to be read back later.

    Args:
        tmp_path: where to write pipeline intermediate files
        stage_name(str): name of the stage (e.g. mapping, reducing)
        key (str): unique string for each task (e.g. "map_57")
        value (str): value for the marker.
    """
    marker_pointer = file_io.append_paths_to_pointer(tmp_path, stage_name, f"{key}_done")
    file_io.write_string_to_file(marker_pointer, value)

def read_markers(self, stage_name) -> dict:
    """Inspect the stage's directory of marker files, fetching the key/value pairs
    from marker file names and contents.

    Args:
        stage_name(str): name of the stage (e.g. mapping, reducing)
    Return:
        dict - mapping of task key (e.g. "map_57") to the marker file's
        contents (the list of lines read from the file)
    """
    prefix = file_io.append_paths_to_pointer(self.tmp_path, stage_name)
    result = {}
    for file_path in file_io.find_files_matching_path(prefix, "*_done"):
        match = re.match(r"(.*)_done", str(file_path.name))
        if not match:  # pragma: no cover (glob guarantees the _done suffix)
            continue
        result[match.group(1)] = file_io.load_text_file(file_path)
    return result

def read_done_keys(self, stage_name):
"""Inspect the stage's directory of done files, fetching the keys from done file names.

Expand Down
1 change: 1 addition & 0 deletions tests/data/markers/mapping/map_001_done
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
45
1 change: 1 addition & 0 deletions tests/data/markers/mapping/map_002_done
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
zippy
16 changes: 16 additions & 0 deletions tests/hats_import/margin_cache/test_margin_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,3 +97,19 @@ def test_margin_cache_gen_negative_pixels(small_sky_source_catalog, tmp_path, da
negative_data = pd.read_parquet(negative_test_file)

assert len(negative_data) > 0


@pytest.mark.dask(timeout=150)
def test_margin_too_small(small_sky_object_catalog, tmp_path, dask_client):
    """Test that margin cache generation raises a helpful error when the requested
    margin is too small to capture any rows."""
    args = MarginCacheArguments(
        margin_threshold=10.0,
        input_catalog_path=small_sky_object_catalog,
        output_path=tmp_path,
        output_artifact_name="catalog_cache",
        margin_order=8,
        progress_bar=False,
    )

    with pytest.raises(ValueError, match="Margin cache contains no rows"):
        mc.generate_margin_cache(args, dask_client)
19 changes: 19 additions & 0 deletions tests/hats_import/margin_cache/test_margin_cache_resume_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,25 @@ def test_some_reducing_task_failures(small_sky_margin_args, dask_client):
plan.wait_for_reducing(futures)


def test_mapping_total(small_sky_margin_args):
    """Summing per-task marker values yields the mapping total, and the total is cached."""
    plan = MarginCachePlan(small_sky_margin_args)

    MarginCachePlan.mapping_key_done(plan.tmp_path, "map_001", 45)
    MarginCachePlan.mapping_key_done(plan.tmp_path, "map_002", 9)
    plan.touch_stage_done_file(MarginCachePlan.MAPPING_STAGE)

    assert plan.read_markers("mapping") == {"map_001": ["45"], "map_002": ["9"]}

    assert plan.get_mapping_total() == 45 + 9

    # Later marker updates are ignored: the previously-computed total is reused.
    MarginCachePlan.mapping_key_done(plan.tmp_path, "map_002", 10)
    assert plan.get_mapping_total() == 54


def test_partition_margin_pixel_pairs(small_sky_source_catalog):
"""Ensure partition_margin_pixel_pairs can generate main partition pixels."""
source_catalog = Catalog.read_hats(small_sky_source_catalog)
Expand Down
6 changes: 6 additions & 0 deletions tests/hats_import/test_pipeline_resume_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,3 +185,9 @@ def test_check_original_input_paths(tmp_path, mixed_schema_csv_dir):
round_trip_files = plan.check_original_input_paths(checked_files)

npt.assert_array_equal(checked_files, round_trip_files)


def test_read_markers(test_data_dir):
    """Marker files on disk round-trip into a key -> contents dictionary."""
    plan = PipelineResumePlan(tmp_path=test_data_dir / "markers", progress_bar=False)
    expected = {"map_001": ["45"], "map_002": ["zippy"]}
    assert plan.read_markers("mapping") == expected