Skip to content

Commit

Permalink
[TEST] Add integration test for training script (#11)
Browse files Browse the repository at this point in the history
  • Loading branch information
f-dangel authored Sep 11, 2024
1 parent 9b2abfb commit b8ef229
Show file tree
Hide file tree
Showing 6 changed files with 52 additions and 44 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ test = [
"pytest",
"pytest-cov",
"pytest-optional-tests",
"torchvision",
]

# Dependencies needed for linting (comma/line-separated)
Expand Down
1 change: 1 addition & 0 deletions test/example/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Contains tests of the example."""
23 changes: 23 additions & 0 deletions test/example/test_train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
"""Check that the training script is working."""

from os import environ, getenv, path
from test.utils import run_verbose

HERE_DIR = path.dirname(path.abspath(__file__))
TRAINING_SCRIPT = path.abspath(path.join(HERE_DIR, "..", "..", "example", "train.py"))


def test_training_script():
"""Execute the training script."""
# Use wandb in offline mode. We do not want to upload the logs this test generates
ORIGINAL_WANDB_MODE = getenv("WANDB_MODE")
environ["WANDB_MODE"] = "offline"

# Run the training script
run_verbose(["python", TRAINING_SCRIPT, "--epochs=3"])

# Restore the original value of WANDB_MODE
if ORIGINAL_WANDB_MODE is None:
environ.pop("WANDB_MODE")
else:
environ["WANDB_MODE"] = ORIGINAL_WANDB_MODE
34 changes: 0 additions & 34 deletions test/test___init__.py

This file was deleted.

27 changes: 27 additions & 0 deletions test/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
"""Utility functions for tests."""

from subprocess import CalledProcessError, CompletedProcess, run
from typing import List


def run_verbose(cmd: List[str]) -> CompletedProcess:
"""Run a command and print stdout & stderr if it fails.
Args:
cmd: The command to run.
Returns:
CompletedProcess: The result of the command.
Raises:
CalledProcessError: If the command fails.
"""
try:
job = run(cmd, capture_output=True, text=True, check=True)
print("STDOUT:", job.stdout)
print("STDERR:", job.stderr)
return job
except CalledProcessError as e:
print("STDOUT:", e.stdout)
print("STDERR:", e.stderr)
raise e
10 changes: 0 additions & 10 deletions wandb_preempt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,3 @@
from wandb_preempt.checkpointer import Checkpointer

__all__ = ["Checkpointer"]


# TODO Remove this function once we have a unit test that uses the checkpointer code
def hello(name):
"""Say hello to a name.
Args:
name (str): Name to say hello to.
"""
print(f"Hello, {name}")

0 comments on commit b8ef229

Please sign in to comment.