From b8ef229b3705261687cea84465ce80c6a7f515f6 Mon Sep 17 00:00:00 2001 From: Felix Dangel <48687646+f-dangel@users.noreply.github.com> Date: Wed, 11 Sep 2024 15:32:08 -0400 Subject: [PATCH] [TEST] Add integration test for training script (#11) --- pyproject.toml | 1 + test/example/__init__.py | 1 + test/example/test_train.py | 23 +++++++++++++++++++++++ test/test___init__.py | 34 ---------------------------------- test/utils.py | 27 +++++++++++++++++++++++++++ wandb_preempt/__init__.py | 10 ---------- 6 files changed, 52 insertions(+), 44 deletions(-) create mode 100644 test/example/__init__.py create mode 100644 test/example/test_train.py delete mode 100644 test/test___init__.py create mode 100644 test/utils.py diff --git a/pyproject.toml b/pyproject.toml index b5ff7b7..5f8fd66 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,6 +47,7 @@ test = [ "pytest", "pytest-cov", "pytest-optional-tests", + "torchvision", ] # Dependencies needed for linting (comma/line-separated) diff --git a/test/example/__init__.py b/test/example/__init__.py new file mode 100644 index 0000000..10c42d5 --- /dev/null +++ b/test/example/__init__.py @@ -0,0 +1 @@ +"""Contains tests of the example.""" diff --git a/test/example/test_train.py b/test/example/test_train.py new file mode 100644 index 0000000..bd880be --- /dev/null +++ b/test/example/test_train.py @@ -0,0 +1,23 @@ +"""Check that the training script is working.""" + +from os import environ, getenv, path +from test.utils import run_verbose + +HERE_DIR = path.dirname(path.abspath(__file__)) +TRAINING_SCRIPT = path.abspath(path.join(HERE_DIR, "..", "..", "example", "train.py")) + + +def test_training_script(): + """Execute the training script.""" + # Use wandb in offline mode. We do not want to upload the logs this test generates + ORIGINAL_WANDB_MODE = getenv("WANDB_MODE") + environ["WANDB_MODE"] = "offline" + + # Run the training script + run_verbose(["python", TRAINING_SCRIPT, "--epochs=3"]) + + # Restore the original value of WANDB_MODE + if ORIGINAL_WANDB_MODE is None: + environ.pop("WANDB_MODE") + else: + environ["WANDB_MODE"] = ORIGINAL_WANDB_MODE diff --git a/test/test___init__.py b/test/test___init__.py deleted file mode 100644 index 73cee8e..0000000 --- a/test/test___init__.py +++ /dev/null @@ -1,34 +0,0 @@ -"""Tests for wandb_preempt/__init__.py.""" - -import time - -import pytest - -import wandb_preempt - -NAMES = ["world", "github"] -IDS = NAMES - - -# TODO Remove this function once we have a unit test that uses the checkpointer code -@pytest.mark.parametrize("name", NAMES, ids=IDS) -def test_hello(name: str): - """Test hello function. - - Args: - name: Name to greet. - """ - wandb_preempt.hello(name) - - -# TODO Remove this function once we have a unit test that uses the checkpointer code -@pytest.mark.expensive -@pytest.mark.parametrize("name", NAMES, ids=IDS) -def test_hello_expensive(name: str): - """Expensive test of hello. Will only be run on master/main and development. - - Args: - name: Name to greet. - """ - time.sleep(1) - wandb_preempt.hello(name) diff --git a/test/utils.py b/test/utils.py new file mode 100644 index 0000000..983804a --- /dev/null +++ b/test/utils.py @@ -0,0 +1,27 @@ +"""Utility functions for tests.""" + +from subprocess import CalledProcessError, CompletedProcess, run +from typing import List + + +def run_verbose(cmd: List[str]) -> CompletedProcess: + """Run a command and print stdout & stderr if it fails. + + Args: + cmd: The command to run. + + Returns: + CompletedProcess: The result of the command. + + Raises: + CalledProcessError: If the command fails. + """ + try: + job = run(cmd, capture_output=True, text=True, check=True) + print("STDOUT:", job.stdout) + print("STDERR:", job.stderr) + return job + except CalledProcessError as e: + print("STDOUT:", e.stdout) + print("STDERR:", e.stderr) + raise e diff --git a/wandb_preempt/__init__.py b/wandb_preempt/__init__.py index 26ffec2..d98d6f5 100644 --- a/wandb_preempt/__init__.py +++ b/wandb_preempt/__init__.py @@ -3,13 +3,3 @@ from wandb_preempt.checkpointer import Checkpointer __all__ = ["Checkpointer"] - - -# TODO Remove this function once we have a unit test that uses the checkpointer code -def hello(name): - """Say hello to a name. - - Args: - name (str): Name to say hello to. - """ - print(f"Hello, {name}")