From 48dc9acc3765954af0a3a46bbbc7b6f0650370c0 Mon Sep 17 00:00:00 2001 From: Melissa DeLucchi Date: Fri, 17 May 2024 15:04:53 -0400 Subject: [PATCH] Spruce up the way we check for original input files. --- src/hipscat_import/pipeline_resume_plan.py | 3 ++- .../hipscat_import/test_pipeline_resume_plan.py | 16 ++++++++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/src/hipscat_import/pipeline_resume_plan.py b/src/hipscat_import/pipeline_resume_plan.py index f223b3c4..e97a3acf 100644 --- a/src/hipscat_import/pipeline_resume_plan.py +++ b/src/hipscat_import/pipeline_resume_plan.py @@ -173,6 +173,7 @@ def check_original_input_paths(self, input_paths): ValueError: if the retrieved file set differs from `input_paths`. """ unique_file_paths = set(input_paths) + unique_file_paths = [str(p) for p in unique_file_paths] original_input_paths = [] @@ -181,7 +182,7 @@ def check_original_input_paths(self, input_paths): with open(file_path, "r", encoding="utf-8") as file_handle: contents = file_handle.readlines() contents = [path.strip() for path in contents] - original_input_paths = set(contents) + original_input_paths = list(set(contents)) except FileNotFoundError: pass diff --git a/tests/hipscat_import/test_pipeline_resume_plan.py b/tests/hipscat_import/test_pipeline_resume_plan.py index c878d827..0f806660 100644 --- a/tests/hipscat_import/test_pipeline_resume_plan.py +++ b/tests/hipscat_import/test_pipeline_resume_plan.py @@ -3,6 +3,7 @@ import os from pathlib import Path +import numpy.testing as npt import pytest from hipscat_import.pipeline_resume_plan import PipelineResumePlan @@ -135,3 +136,18 @@ def test_formatted_stage_name(): formatted = PipelineResumePlan.get_formatted_stage_name("very long stage name") assert formatted == "Very long stage name" + + +def test_check_original_input_paths(tmp_path, mixed_schema_csv_dir): + plan = PipelineResumePlan(tmp_path=tmp_path, progress_bar=False, resume=False) + + input_file_list = [ + Path(mixed_schema_csv_dir) / "input_01.csv", + Path(mixed_schema_csv_dir) / "input_02.csv", + ] + + checked_files = plan.check_original_input_paths(input_file_list) + + round_trip_files = plan.check_original_input_paths(checked_files) + + npt.assert_array_equal(checked_files, round_trip_files)