Skip to content

Commit

Permalink
Spruce up the way we check for original input files. (#314)
Browse files Browse the repository at this point in the history
* Spruce up the way we check for original input files.

* Improve variable names and usage.
  • Loading branch information
delucchi-cmu committed May 22, 2024
1 parent 8b19bc2 commit 37ce248
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 9 deletions.
18 changes: 9 additions & 9 deletions src/hipscat_import/pipeline_resume_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,30 +172,30 @@ def check_original_input_paths(self, input_paths):
Raises:
ValueError: if the retrieved file set differs from `input_paths`.
"""
unique_file_paths = set(input_paths)
input_paths = set(input_paths)
input_paths = [str(p) for p in input_paths]
input_paths.sort()

original_input_paths = []

file_path = file_io.append_paths_to_pointer(self.tmp_path, self.ORIGINAL_INPUT_PATHS)
log_file_path = file_io.append_paths_to_pointer(self.tmp_path, self.ORIGINAL_INPUT_PATHS)
try:
with open(file_path, "r", encoding="utf-8") as file_handle:
with open(log_file_path, "r", encoding="utf-8") as file_handle:
contents = file_handle.readlines()
contents = [path.strip() for path in contents]
original_input_paths = set(contents)
original_input_paths = list(set(contents))
original_input_paths.sort()
except FileNotFoundError:
pass

if len(original_input_paths) == 0:
file_path = file_io.append_paths_to_pointer(self.tmp_path, self.ORIGINAL_INPUT_PATHS)
with open(file_path, "w", encoding="utf-8") as file_handle:
with open(log_file_path, "w", encoding="utf-8") as file_handle:
for path in input_paths:
file_handle.write(f"{path}\n")
else:
if original_input_paths != unique_file_paths:
if original_input_paths != input_paths:
raise ValueError("Different file set from resumed pipeline execution.")

input_paths = list(unique_file_paths)
input_paths.sort()
return input_paths


Expand Down
16 changes: 16 additions & 0 deletions tests/hipscat_import/test_pipeline_resume_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
from pathlib import Path

import numpy.testing as npt
import pytest

from hipscat_import.pipeline_resume_plan import PipelineResumePlan
Expand Down Expand Up @@ -135,3 +136,18 @@ def test_formatted_stage_name():

formatted = PipelineResumePlan.get_formatted_stage_name("very long stage name")
assert formatted == "Very long stage name"


def test_check_original_input_paths(tmp_path, mixed_schema_csv_dir):
    """Checking the same inputs twice should hand back an identical list.

    The first call receives ``Path`` objects; whatever it returns is fed
    straight back into a second call, and the two results are compared.
    """
    resume_plan = PipelineResumePlan(tmp_path=tmp_path, progress_bar=False, resume=False)

    csv_dir = Path(mixed_schema_csv_dir)
    requested_files = [csv_dir / file_name for file_name in ("input_01.csv", "input_02.csv")]

    first_pass = resume_plan.check_original_input_paths(requested_files)
    # Re-submit the output of the first call as the input of the second.
    second_pass = resume_plan.check_original_input_paths(first_pass)

    npt.assert_array_equal(first_pass, second_pass)

0 comments on commit 37ce248

Please sign in to comment.