[REF] Improve name, define a preliminary API
f-dangel committed Sep 6, 2024
1 parent 1e2e17a commit 357acc7
Showing 5 changed files with 23 additions and 13 deletions.
2 changes: 1 addition & 1 deletion example/launch.sh
@@ -27,7 +27,7 @@ child="$!"
function term_handler()
{
echo "$(date) ** Job $SLURM_JOB_NAME ($SLURM_JOB_ID) received SIGUSR1 at $(date) **"
-# The CheckpointHandler will have written the PID of the Python process to a file
+# The Checkpointer will have written the PID of the Python process to a file
# so we can send it the SIGUSR1 signal
PID=$(cat "${SLURM_JOB_ID}.pid")
echo "$(date) ** Sending kill signal to python process $PID **"
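For context, the handler above relies on a contract with the Python side: per the comment in the hunk, the Checkpointer writes the Python process's PID to `<SLURM_JOB_ID>.pid`, and that process must react to `SIGUSR1`. A minimal Python sketch of that contract, where everything except the PID file is an assumption about the implementation:

```python
import os
import signal

# Write this process's PID to "<SLURM_JOB_ID>.pid" so that launch.sh's
# term_handler can signal it (mirrors the comment in the hunk above).
with open(f"{os.environ['SLURM_JOB_ID']}.pid", "w") as f:
    f.write(str(os.getpid()))


def on_sigusr1(signum, frame):
    # Hypothetical handler: a real implementation would set a flag that the
    # next checkpointing step checks before requeuing the job.
    print("SIGUSR1 received: checkpoint and requeue at the next step().")


signal.signal(signal.SIGUSR1, on_sigusr1)
```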
14 changes: 7 additions & 7 deletions example/train.py
@@ -15,7 +15,7 @@
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

-from wandb_preempt.checkpointer import CheckpointHandler, get_resume_value
+from wandb_preempt.checkpointer import Checkpointer, get_resume_value

parser = ArgumentParser("Train a simple CNN on MNIST using SGD.")
parser.add_argument("--lr", type=float, default=0.01, help="SGD's learning rate.")
@@ -59,7 +59,7 @@

# NOTE: Set up a check-pointer which will load and save checkpoints.
# Pass the run ID to obtain unique file names for the checkpoints.
-checkpoint_handler = CheckpointHandler(
+checkpointer = Checkpointer(
run.id,
model,
optimizer,
@@ -72,7 +72,7 @@
# NOTE: If existing, load model, optimizer, and learning rate scheduler state from
# latest checkpoint, set random number generator states, and recover the epoch to start
# training from. Does nothing if there was no checkpoint.
-start_epoch = checkpoint_handler.load_latest_checkpoint()
+start_epoch = checkpointer.load_latest_checkpoint()

# training
for epoch in range(start_epoch, args.max_epochs):
@@ -91,7 +91,7 @@
"loss": loss.item(),
"lr": optimizer.param_groups[0]["lr"],
"loss_scale": scaler.get_scale(),
"resumes": checkpoint_handler.num_resumes,
"resumes": checkpointer.num_resumes,
}
)

@@ -104,13 +104,13 @@
# NOTE Put validation code here
# eval(model, ...)

-# NOTE Call checkpoint_handler.step() at the end of the epoch to save a checkpoint.
+# NOTE Call checkpointer.step() at the end of the epoch to save a checkpoint.
# If SLURM sent us a signal that our time for this job is running out, it will now
# also take care of pre-empting the wandb job and requeuing the SLURM job, killing
# the current python training script to resume with the requeued job.
-checkpoint_handler.step()
+checkpointer.step()

wandb.finish()
# NOTE Remove all created checkpoints once we are done training. If you want to
# keep the trained model, remove this line.
-checkpoint_handler.remove_checkpoints()
+checkpointer.remove_checkpoints()
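Pieced together from the hunks above, the post-rename training-loop integration looks roughly as follows. This is a sketch, not the full example script: the `wandb.init(resume=get_resume_value())` wiring is a guess based on the helper's name (the call is not visible in this diff), and constructor arguments hidden by the collapsed hunks are omitted:

```python
import torch
import wandb

from wandb_preempt.checkpointer import Checkpointer, get_resume_value

# Presumably resumes the same wandb run after a requeue (not shown in the diff).
run = wandb.init(project="mnist-example", resume=get_resume_value())

model = torch.nn.Linear(784, 10)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Further arguments (scheduler, scaler, savedir, verbose, ...) are hidden by
# the collapsed hunks and therefore omitted here.
checkpointer = Checkpointer(run.id, model, optimizer)

start_epoch = checkpointer.load_latest_checkpoint()  # 0 without a checkpoint
for epoch in range(start_epoch, 10):
    # ... one epoch of training, logging checkpointer.num_resumes with the loss
    checkpointer.step()  # save; requeues and exits if SIGUSR1 was received

wandb.finish()
checkpointer.remove_checkpoints()  # drop this line to keep the trained model
```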
2 changes: 2 additions & 0 deletions test/test___init__.py
@@ -10,6 +10,7 @@
IDS = NAMES


+# TODO Remove this function once we have a unit test that uses the checkpointer code
@pytest.mark.parametrize("name", NAMES, ids=IDS)
def test_hello(name: str):
"""Test hello function.
@@ -20,6 +21,7 @@ def test_hello(name: str):
wandb_preempt.hello(name)


+# TODO Remove this function once we have a unit test that uses the checkpointer code
@pytest.mark.expensive
@pytest.mark.parametrize("name", NAMES, ids=IDS)
def test_hello_expensive(name: str):
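The TODOs above anticipate a real test of the checkpointer. A hedged sketch of what such a test could look like, assuming the `Checkpointer(run_id, model, optimizer, ...)` call shape and the `savedir` keyword visible in the other hunks, and guessing that `load_latest_checkpoint` returns 0 when no checkpoint exists:

```python
import torch

from wandb_preempt.checkpointer import Checkpointer


def test_checkpointer_roundtrip(tmp_path):
    """Save a checkpoint with step(), then clean up with remove_checkpoints()."""
    model = torch.nn.Linear(2, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    checkpointer = Checkpointer("dummy-run-id", model, optimizer, savedir=str(tmp_path))

    assert checkpointer.load_latest_checkpoint() == 0  # no checkpoint yet
    checkpointer.step()  # should write a checkpoint for epoch 0
    assert any(tmp_path.iterdir())  # something was saved

    checkpointer.remove_checkpoints()
    assert not any(tmp_path.iterdir())  # checkpoints are gone again
```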
8 changes: 8 additions & 0 deletions wandb_preempt/__init__.py
@@ -1,6 +1,14 @@
"""wandb_preempt library."""

+from wandb_preempt.checkpointer import Checkpointer, get_resume_value

+__all__ = [
+    "Checkpointer",
+    "get_resume_value",
+]


+# TODO Remove this function once we have a unit test that uses the checkpointer code
def hello(name):
"""Say hello to a name.
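With the import and `__all__` added here, both entry points become part of the package's top-level API, so downstream code can import them from the package root; a quick check:

```python
import wandb_preempt

# Both names are now re-exported at the package root.
from wandb_preempt import Checkpointer, get_resume_value

print(wandb_preempt.__all__)  # ['Checkpointer', 'get_resume_value']
```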
10 changes: 5 additions & 5 deletions wandb_preempt/checkpointer.py
@@ -55,16 +55,16 @@ def get_resume_value(verbose: bool = False) -> str:
return resume


-class CheckpointHandler:
+class Checkpointer:
"""Class for storing, loading, and removing checkpoints.
Can be marked as pre-empted by sending a `SIGUSR1` signal to a Python session.
How to use this class:
-  - Create an instance in your training loop, `handler = CheckpointHandler(...)`.
-  - At the end of each epoch, call `handler.step()` to save a checkpoint.
-    If the job received the `SIGUSR1` signal, the handler will requeue the job at
+  - Create an instance in your training loop, `checkpointer = Checkpointer(...)`.
+  - At the end of each epoch, call `checkpointer.step()` to save a checkpoint.
+    If the job received the `SIGUSR1` signal, the checkpointer will requeue the job at
the end of its checkpointing step.
"""

@@ -79,7 +79,7 @@ def __init__(
savedir: str = "checkpoints",
verbose: bool = False,
) -> None:
"""Set up a checkpoint handler.
"""Set up a checkpointer.
Args:
run_id: A unique identifier for this run.
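The body of `step()` lies outside this diff. Based on the class docstring and the comments in train.py, one plausible shape is sketched below; everything here is an assumption, including the `save_checkpoint` helper and the `marked_preempted` flag, while `wandb.mark_preempting()` and `scontrol requeue` are the standard wandb and SLURM mechanisms for this pattern:

```python
import os
import subprocess
import sys

import wandb


def step(self) -> None:
    """Hypothetical sketch of Checkpointer.step, not the real implementation."""
    self.save_checkpoint()  # assumed helper that writes this epoch's checkpoint
    if self.marked_preempted:  # assumed flag set by the SIGUSR1 handler
        wandb.mark_preempting()  # flag the wandb run as pre-empted
        subprocess.run(["scontrol", "requeue", os.environ["SLURM_JOB_ID"]], check=True)
        sys.exit(0)  # kill the current script; the requeued job resumes from here
```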
