Skip to content

Commit

Permalink
test: Cleanup tmp dirs in test dir (#2917)
Browse files Browse the repository at this point in the history
Signed-off-by: JiaWei Jiang <[email protected]>
  • Loading branch information
JiangJiaWei1103 authored Nov 12, 2024
1 parent 3f0ab84 commit f09c536
Showing 1 changed file with 49 additions and 26 deletions.
75 changes: 49 additions & 26 deletions tests/flytekit/unit/types/directory/test_dir.py
Original file line number Diff line number Diff line change
@@ -1,52 +1,75 @@
import tempfile
from pathlib import Path
from typing import Optional

import pytest
import flytekit
from flytekit import task, workflow
from flytekit.types.directory import FlyteDirectory


def test_src_path_with_different_types() -> None:
N_FILES = 3
N_FILES_PER_DIR = 3

@pytest.fixture
def local_tmp_dirs():
# Create a source directory
src_dir = tempfile.TemporaryDirectory()
for file_idx in range(N_FILES_PER_DIR):
with open(Path(src_dir.name) / f"{file_idx}.txt", "w") as f:
f.write(str(file_idx))

# Create an empty directory as the destination
dst_dir = tempfile.TemporaryDirectory()

yield src_dir.name, dst_dir.name

# Cleanup
src_dir.cleanup()
dst_dir.cleanup()

@task
def write_fidx_task(
use_str_src_path: bool, remote_dir: Optional[str] = None
) -> FlyteDirectory:
"""Write file indices to text files in a source path."""
source_path = Path(flytekit.current_context().working_directory) / "txt_files"
source_path.mkdir(exist_ok=True)

for file_idx in range(N_FILES):
file_path = source_path / f"{file_idx}.txt"
with file_path.open(mode="w") as f:
f.write(str(file_idx))
def test_src_path_with_different_types(local_tmp_dirs) -> None:

if use_str_src_path:
source_path = str(source_path)
@task
def create_flytedir(
source_path: str,
use_pathlike_src_path: bool,
remote_dir: Optional[str] = None
) -> FlyteDirectory:
if use_pathlike_src_path:
source_path = Path(source_path)
fd = FlyteDirectory(path=source_path, remote_directory=remote_dir)

return fd

@workflow
def wf(use_str_src_path: bool, remote_dir: Optional[str] = None) -> FlyteDirectory:
return write_fidx_task(use_str_src_path=use_str_src_path, remote_dir=remote_dir)
def wf(
source_path: str,
use_pathlike_src_path: bool,
remote_dir: Optional[str] = None
) -> FlyteDirectory:
return create_flytedir(
source_path=source_path, use_pathlike_src_path=use_pathlike_src_path, remote_dir=remote_dir
)

def _verify_files(fd: FlyteDirectory) -> None:
for file_idx in range(N_FILES):
for file_idx in range(N_FILES_PER_DIR):
with open(fd / f"{file_idx}.txt", "r") as f:
assert f.read() == str(file_idx)


source_path, remote_dir = local_tmp_dirs

# Source path is of type str
ff_1 = wf(use_str_src_path=True, remote_dir=None)
_verify_files(ff_1)
fd_1 = wf(source_path=source_path, use_pathlike_src_path=False, remote_dir=None)
_verify_files(fd_1)

ff_2 = wf(use_str_src_path=True, remote_dir="./my_txt_files")
_verify_files(ff_2)
fd_2 = wf(source_path=source_path, use_pathlike_src_path=False, remote_dir=remote_dir)
_verify_files(fd_2)

# Source path is of type pathlib.PosixPath
ff_3 = wf(use_str_src_path=False, remote_dir=None)
_verify_files(ff_3)
fd_3 = wf(source_path=source_path, use_pathlike_src_path=True, remote_dir=None)
_verify_files(fd_3)

ff_4 = wf(use_str_src_path=False, remote_dir="./my_txt_files2")
_verify_files(ff_4)
fd_4 = wf(source_path=source_path, use_pathlike_src_path=True, remote_dir=remote_dir)
_verify_files(fd_4)

0 comments on commit f09c536

Please sign in to comment.