diff --git a/example/launch.sh b/example/launch.sh
index e65523d..3d4ce27 100644
--- a/example/launch.sh
+++ b/example/launch.sh
@@ -27,7 +27,7 @@ child="$!"
 
 function term_handler() {
     echo "$(date) ** Job $SLURM_JOB_NAME ($SLURM_JOB_ID) received SIGUSR1 at $(date) **"
-    # The CheckpointHandler will have written the PID of the Python process to a file
+    # The Checkpointer will have written the PID of the Python process to a file
     # so we can send it the SIGUSR1 signal
     PID=$(cat "${SLURM_JOB_ID}.pid")
    echo "$(date) ** Sending kill signal to python process $PID **"
diff --git a/example/train.py b/example/train.py
index 424f4d2..fbbb3fd 100644
--- a/example/train.py
+++ b/example/train.py
@@ -15,7 +15,7 @@
 from torchvision.datasets import MNIST
 from torchvision.transforms import ToTensor
 
-from wandb_preempt.checkpointer import CheckpointHandler, get_resume_value
+from wandb_preempt.checkpointer import Checkpointer, get_resume_value
 
 parser = ArgumentParser("Train a simple CNN on MNIST using SGD.")
 parser.add_argument("--lr", type=float, default=0.01, help="SGD's learning rate.")
@@ -59,7 +59,7 @@
 
 # NOTE: Set up a check-pointer which will load and save checkpoints.
 # Pass the run ID to obtain unique file names for the checkpoints.
-checkpoint_handler = CheckpointHandler(
+checkpointer = Checkpointer(
     run.id,
     model,
     optimizer,
@@ -72,7 +72,7 @@
 # NOTE: If existing, load model, optimizer, and learning rate scheduler state from
 # latest checkpoint, set random number generator states, and recover the epoch to start
 # training from. Does nothing if there was no checkpoint.
-start_epoch = checkpoint_handler.load_latest_checkpoint()
+start_epoch = checkpointer.load_latest_checkpoint()
 
 # training
 for epoch in range(start_epoch, args.max_epochs):
@@ -91,7 +91,7 @@
             "loss": loss.item(),
             "lr": optimizer.param_groups[0]["lr"],
             "loss_scale": scaler.get_scale(),
-            "resumes": checkpoint_handler.num_resumes,
+            "resumes": checkpointer.num_resumes,
         }
     )
 
@@ -104,13 +104,13 @@
     # NOTE Put validation code here
    # eval(model, ...)
-    # NOTE Call checkpoint_handler.step() at the end of the epoch to save a checkpoint.
+    # NOTE Call checkpointer.step() at the end of the epoch to save a checkpoint.
     # If SLURM sent us a signal that our time for this job is running out, it will now
     # also take care of pre-empting the wandb job and requeuing the SLURM job, killing
     # the current python training script to resume with the requeued job.
-    checkpoint_handler.step()
+    checkpointer.step()
 
 wandb.finish()
 
 # NOTE Remove all created checkpoints once we are done training. If you want to
 # keep the trained model, remove this line.
-checkpoint_handler.remove_checkpoints()
+checkpointer.remove_checkpoints()
diff --git a/test/test___init__.py b/test/test___init__.py
index 3ddc032..73cee8e 100644
--- a/test/test___init__.py
+++ b/test/test___init__.py
@@ -10,6 +10,7 @@
 IDS = NAMES
 
 
+# TODO Remove this function once we have a unit test that uses the checkpointer code
 @pytest.mark.parametrize("name", NAMES, ids=IDS)
 def test_hello(name: str):
     """Test hello function.
@@ -20,6 +21,7 @@ def test_hello(name: str):
     wandb_preempt.hello(name)
 
 
+# TODO Remove this function once we have a unit test that uses the checkpointer code
 @pytest.mark.expensive
 @pytest.mark.parametrize("name", NAMES, ids=IDS)
 def test_hello_expensive(name: str):
diff --git a/wandb_preempt/__init__.py b/wandb_preempt/__init__.py
index f5f2027..c67e208 100644
--- a/wandb_preempt/__init__.py
+++ b/wandb_preempt/__init__.py
@@ -1,6 +1,14 @@
 """wandb_preempt library."""
 
 
+from wandb_preempt.checkpointer import Checkpointer, get_resume_value
+__all__ = [
+    "Checkpointer",
+    "get_resume_value",
+]
+
+
+# TODO Remove this function once we have a unit test that uses the checkpointer code
 def hello(name):
     """Say hello to a name.
 
diff --git a/wandb_preempt/checkpointer.py b/wandb_preempt/checkpointer.py
index 7a4151a..ec9a4fa 100644
--- a/wandb_preempt/checkpointer.py
+++ b/wandb_preempt/checkpointer.py
@@ -55,16 +55,16 @@ def get_resume_value(verbose: bool = False) -> str:
     return resume
 
 
-class CheckpointHandler:
+class Checkpointer:
     """Class for storing, loading, and removing checkpoints.
 
     Can be marked as pre-empted by sending a `SIGUSR1` signal to a Python session.
 
     How to use this class:
 
-    - Create an instance in your training loop, `handler = CheckpointHandler(...)`.
-    - At the end of each epoch, call `handler.step()` to save a checkpoint.
-      If the job received the `SIGUSR1` signal, the handler will requeue the at
+    - Create an instance in your training loop, `checkpointer = Checkpointer(...)`.
+    - At the end of each epoch, call `checkpointer.step()` to save a checkpoint.
+      If the job received the `SIGUSR1` signal, the checkpointer will requeue the job at
       the end of its checkpointing step.
     """
 
@@ -79,7 +79,7 @@ def __init__(
         savedir: str = "checkpoints",
         verbose: bool = False,
     ) -> None:
-        """Set up a checkpoint handler.
+        """Set up a checkpointer.
 
         Args:
             run_id: A unique identifier for this run.
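For reference, a minimal sketch of how the renamed class is used, condensed from example/train.py above. This is an illustration, not the full example: the toy model, optimizer, project name, and the wandb.init(resume=...) wiring are assumptions added here, and the constructor may take further arguments that the hunk above truncates after `optimizer`. The method and attribute names (load_latest_checkpoint, step, remove_checkpoints, num_resumes) are the ones appearing in this diff.

# Minimal usage sketch of the renamed Checkpointer (assumptions noted in comments).
import torch
import wandb

from wandb_preempt import Checkpointer, get_resume_value

model = torch.nn.Linear(10, 1)  # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  # placeholder optimizer

# Assumed wiring: get_resume_value() tells wandb whether this is a requeued run.
run = wandb.init(project="my-project", resume=get_resume_value(verbose=True))

# Pass the run ID so checkpoint file names are unique per run; savedir and
# verbose are the keyword arguments visible in Checkpointer.__init__ above.
checkpointer = Checkpointer(run.id, model, optimizer, savedir="checkpoints", verbose=True)

# Load the latest checkpoint if one exists and return the epoch to resume from.
start_epoch = checkpointer.load_latest_checkpoint()

max_epochs = 10  # placeholder
for epoch in range(start_epoch, max_epochs):
    # ... training steps; checkpointer.num_resumes can be logged to wandb ...
    checkpointer.step()  # save a checkpoint; requeues the job if SIGUSR1 was received

wandb.finish()
checkpointer.remove_checkpoints()  # delete checkpoints; remove this line to keep them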