diff --git a/src/hipscat_import/pipeline_resume_plan.py b/src/hipscat_import/pipeline_resume_plan.py index f223b3c4..3511d1da 100644 --- a/src/hipscat_import/pipeline_resume_plan.py +++ b/src/hipscat_import/pipeline_resume_plan.py @@ -172,30 +172,30 @@ def check_original_input_paths(self, input_paths): Raises: ValueError: if the retrieved file set differs from `input_paths`. """ - unique_file_paths = set(input_paths) + input_paths = set(input_paths) + input_paths = [str(p) for p in input_paths] + input_paths.sort() original_input_paths = [] - file_path = file_io.append_paths_to_pointer(self.tmp_path, self.ORIGINAL_INPUT_PATHS) + log_file_path = file_io.append_paths_to_pointer(self.tmp_path, self.ORIGINAL_INPUT_PATHS) try: - with open(file_path, "r", encoding="utf-8") as file_handle: + with open(log_file_path, "r", encoding="utf-8") as file_handle: contents = file_handle.readlines() contents = [path.strip() for path in contents] - original_input_paths = set(contents) + original_input_paths = list(set(contents)) + original_input_paths.sort() except FileNotFoundError: pass if len(original_input_paths) == 0: - file_path = file_io.append_paths_to_pointer(self.tmp_path, self.ORIGINAL_INPUT_PATHS) - with open(file_path, "w", encoding="utf-8") as file_handle: + with open(log_file_path, "w", encoding="utf-8") as file_handle: for path in input_paths: file_handle.write(f"{path}\n") else: - if original_input_paths != unique_file_paths: + if original_input_paths != input_paths: raise ValueError("Different file set from resumed pipeline execution.") - input_paths = list(unique_file_paths) - input_paths.sort() return input_paths diff --git a/tests/hipscat_import/test_pipeline_resume_plan.py b/tests/hipscat_import/test_pipeline_resume_plan.py index c878d827..0f806660 100644 --- a/tests/hipscat_import/test_pipeline_resume_plan.py +++ b/tests/hipscat_import/test_pipeline_resume_plan.py @@ -3,6 +3,7 @@ import os from pathlib import Path +import numpy.testing as npt import pytest from hipscat_import.pipeline_resume_plan import PipelineResumePlan @@ -135,3 +136,18 @@ def test_formatted_stage_name(): formatted = PipelineResumePlan.get_formatted_stage_name("very long stage name") assert formatted == "Very long stage name" + + +def test_check_original_input_paths(tmp_path, mixed_schema_csv_dir): + plan = PipelineResumePlan(tmp_path=tmp_path, progress_bar=False, resume=False) + + input_file_list = [ + Path(mixed_schema_csv_dir) / "input_01.csv", + Path(mixed_schema_csv_dir) / "input_02.csv", + ] + + checked_files = plan.check_original_input_paths(input_file_list) + + round_trip_files = plan.check_original_input_paths(checked_files) + + npt.assert_array_equal(checked_files, round_trip_files)