diff --git a/tests/flytekit/unit/types/directory/test_dir.py b/tests/flytekit/unit/types/directory/test_dir.py index 285162d1e9..9d311cedae 100644 --- a/tests/flytekit/unit/types/directory/test_dir.py +++ b/tests/flytekit/unit/types/directory/test_dir.py @@ -1,52 +1,75 @@ +import tempfile from pathlib import Path from typing import Optional +import pytest import flytekit from flytekit import task, workflow from flytekit.types.directory import FlyteDirectory -def test_src_path_with_different_types() -> None: - N_FILES = 3 +N_FILES_PER_DIR = 3 + +@pytest.fixture +def local_tmp_dirs(): + # Create a source directory + src_dir = tempfile.TemporaryDirectory() + for file_idx in range(N_FILES_PER_DIR): + with open(Path(src_dir.name) / f"{file_idx}.txt", "w") as f: + f.write(str(file_idx)) + + # Create an empty directory as the destination + dst_dir = tempfile.TemporaryDirectory() + + yield src_dir.name, dst_dir.name + + # Cleanup + src_dir.cleanup() + dst_dir.cleanup() - @task - def write_fidx_task( - use_str_src_path: bool, remote_dir: Optional[str] = None - ) -> FlyteDirectory: - """Write file indices to text files in a source path.""" - source_path = Path(flytekit.current_context().working_directory) / "txt_files" - source_path.mkdir(exist_ok=True) - for file_idx in range(N_FILES): - file_path = source_path / f"{file_idx}.txt" - with file_path.open(mode="w") as f: - f.write(str(file_idx)) +def test_src_path_with_different_types(local_tmp_dirs) -> None: - if use_str_src_path: - source_path = str(source_path) + @task + def create_flytedir( + source_path: str, + use_pathlike_src_path: bool, + remote_dir: Optional[str] = None + ) -> FlyteDirectory: + if use_pathlike_src_path: + source_path = Path(source_path) fd = FlyteDirectory(path=source_path, remote_directory=remote_dir) return fd @workflow - def wf(use_str_src_path: bool, remote_dir: Optional[str] = None) -> FlyteDirectory: - return write_fidx_task(use_str_src_path=use_str_src_path, remote_dir=remote_dir) + def wf( + source_path: str, + use_pathlike_src_path: bool, + remote_dir: Optional[str] = None + ) -> FlyteDirectory: + return create_flytedir( + source_path=source_path, use_pathlike_src_path=use_pathlike_src_path, remote_dir=remote_dir + ) def _verify_files(fd: FlyteDirectory) -> None: - for file_idx in range(N_FILES): + for file_idx in range(N_FILES_PER_DIR): with open(fd / f"{file_idx}.txt", "r") as f: assert f.read() == str(file_idx) + + source_path, remote_dir = local_tmp_dirs + # Source path is of type str - ff_1 = wf(use_str_src_path=True, remote_dir=None) - _verify_files(ff_1) + fd_1 = wf(source_path=source_path, use_pathlike_src_path=False, remote_dir=None) + _verify_files(fd_1) - ff_2 = wf(use_str_src_path=True, remote_dir="./my_txt_files") - _verify_files(ff_2) + fd_2 = wf(source_path=source_path, use_pathlike_src_path=False, remote_dir=remote_dir) + _verify_files(fd_2) # Source path is of type pathlib.PosixPath - ff_3 = wf(use_str_src_path=False, remote_dir=None) - _verify_files(ff_3) + fd_3 = wf(source_path=source_path, use_pathlike_src_path=True, remote_dir=None) + _verify_files(fd_3) - ff_4 = wf(use_str_src_path=False, remote_dir="./my_txt_files2") - _verify_files(ff_4) + fd_4 = wf(source_path=source_path, use_pathlike_src_path=True, remote_dir=remote_dir) + _verify_files(fd_4)