[DRAFT] Walkthrough #9

Merged · 33 commits · Sep 11, 2024
Commits
ca591b5
[DOC] Start writing walkthrough
f-dangel Sep 8, 2024
1304ddb
[DOC] Progress on walkthrough
f-dangel Sep 9, 2024
b09e7f3
Merge branch 'outside-slurm' into walkthrough
f-dangel Sep 9, 2024
36a8081
[DOC] More progress
f-dangel Sep 9, 2024
7bfe1cc
Merge branch 'main' into walkthrough
f-dangel Sep 10, 2024
a94e461
[ADD] Create new sweep because training script args changed
f-dangel Sep 10, 2024
a8c2472
[BUG] Always return tuple in `load_latest_checkpoint`
f-dangel Sep 10, 2024
0945d5c
[DOC] Add some screenshots
f-dangel Sep 10, 2024
a4ce435
[FIX] SLURM array detection
f-dangel Sep 10, 2024
97fb006
[DOC] Link to wandb and Slurm webpages
scottclowe Sep 10, 2024
f841800
[DOC] Fix docstrings
scottclowe Sep 10, 2024
408da21
[DOC] Tidy up doc building instructions
scottclowe Sep 10, 2024
ce204cf
[DOC] Add instructions to clone github repo
scottclowe Sep 10, 2024
c5b244c
[DOC] Walkthrough improvements
scottclowe Sep 10, 2024
65e3391
[MNT] Rename project quickstart -> example-preemptable-sweep
scottclowe Sep 10, 2024
1433b27
[MNT] Use default entity in example sweep
scottclowe Sep 10, 2024
de9b69b
[DOC] Need to show remove_checkpoints in API as we use it in example
scottclowe Sep 10, 2024
095891c
[BUG] Need to load checkpoint step count from file
scottclowe Sep 11, 2024
fef68fc
[DOC] Update docs
scottclowe Sep 11, 2024
71564c6
[ENH] Record epoch on wandb.log
scottclowe Sep 11, 2024
fb0e641
[BUG] Don't duplicate lr name with config and measure fields
scottclowe Sep 11, 2024
88664d9
[MNT] Import torch.amp instead of deprecated torch.cuda.amp
scottclowe Sep 11, 2024
1863820
[MNT] Increase epochs 10 -> 15
scottclowe Sep 11, 2024
289fbb6
[BUG] Need to load step_count with one more than step count because w…
scottclowe Sep 11, 2024
5e6c505
[DOC] Change to use lr_max in walkthrough
scottclowe Sep 11, 2024
7b5c5bd
Revert "[MNT] Import torch.amp instead of deprecated torch.cuda.amp"
scottclowe Sep 11, 2024
cf45b0a
[MNT] Change preempting print statement
scottclowe Sep 11, 2024
de0fad0
[DOC] Update README
f-dangel Sep 11, 2024
bbe7c02
[DEL] Remove unbuffered flag
f-dangel Sep 11, 2024
ebd88d7
[REF] Same epoch default value (20) for train script and sweep
f-dangel Sep 11, 2024
581806e
[ADD] Update sweep id
f-dangel Sep 11, 2024
a2b8f0b
[DOC] More screenshots, roughly complete content
f-dangel Sep 11, 2024
4b61859
Merge branch 'main' into walkthrough
f-dangel Sep 11, 2024
18 changes: 9 additions & 9 deletions README.md
@@ -2,18 +2,18 @@

[![Documentation Status](https://readthedocs.org/projects/wandb-preempt/badge/?version=latest)](https://wandb-preempt.readthedocs.io/en/latest/?badge=latest)

-This repository contains a tutorial on how to combine `wandb` sweeps with
-Slurm's pre-emption, i.e. how to automatically re-queue and resume runs from a
-Weights & Biases sweep on a Slurm cluster.
+This repository contains a tutorial on how to combine [wandb](https://wandb.ai/) sweeps
+with [Slurm](https://slurm.schedmd.com/)'s pre-emption, i.e. how to automatically
+re-queue and resume runs from a Weights & Biases sweep on a Slurm cluster.

-This is work in progress:
+## Getting started

-- TODO Debug failure scenarios
-
-- TODO Write a self-contained walk-through.
-
-## Installation
+### Installation

```bash
pip install git+https://github.com/f-dangel/wandb_preempt.git@main
```

+### Basic Example
+
+Please see the [docs](https://wandb-preempt.readthedocs.io/en/latest/walkthrough/).
Empty file removed docs/README.md
1 change: 1 addition & 0 deletions docs/api.md
@@ -4,3 +4,4 @@
- __init__
- load_latest_checkpoint
- step
+- remove_checkpoints
Binary file added docs/assets/01_empty_sweep.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/assets/02_local_run.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/assets/03_slurm_preempted.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/assets/04_slurm_finished.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
8 changes: 6 additions & 2 deletions docs/develop.md
@@ -41,5 +41,9 @@ via `make`:

We use the [Google docstring
convention](https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html)
-and `mkdocs` which allows using markdown syntax in a docstring to achieve
-formatting. To build the docs, run `mkdocs serve` in the repository root.
+and [mkdocs](https://www.mkdocs.org/) which allows using markdown syntax in a docstring
+to achieve formatting. To build the docs, install the documentation requirements
+```bash
+pip install -e ."[doc]"
+```
+and run `mkdocs serve` from the repository root.
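
As a quick reference for that convention, here is a minimal sketch of a Google-style docstring; the function itself is a hypothetical example, not part of the package:

```python
def scale(value: float, factor: float = 2.0) -> float:
    """Scale a value by a constant factor.

    Args:
        value: The number to scale.
        factor: Multiplicative factor. Default: `2.0`.

    Returns:
        The scaled value, i.e. `value * factor`.
    """
    return value * factor
```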
1 change: 1 addition & 0 deletions docs/walkthrough.md
334 changes: 334 additions & 0 deletions example/README.md

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions example/launch.sh
@@ -5,11 +5,11 @@
#SBATCH --gres=gpu:1
#SBATCH --cpus-per-gpu=4
#SBATCH --mem-per-gpu=16G
-#SBATCH --time=00:04:00
-#SBATCH --qos=m5
-#SBATCH --array=0-19
-#SBATCH --signal=B:SIGUSR1@120 # Send signal SIGUSR1 120 seconds before the job hits the time limit
+#SBATCH --open-mode=append
+#SBATCH --time=00:04:00
+#SBATCH --array=0-9
+#SBATCH --signal=B:SIGUSR1@120 # Send signal SIGUSR1 120 seconds before the job hits the time limit

echo "Job $SLURM_JOB_NAME ($SLURM_JOB_ID) begins on $(hostname), submitted from $SLURM_SUBMIT_HOST ($SLURM_CLUSTER_NAME)"
echo ""
@@ -20,7 +20,7 @@ if [ "$SLURM_ARRAY_TASK_COUNT" != "" ]; then
fi

# NOTE that we need to use srun here, otherwise the Python process won't receive the SIGUSR1 signal
-srun --unbuffered wandb agent --count=1 f-dangel-team/quickstart/i75puhon &
+srun wandb agent --count=1 f-dangel-team/example-preemptable-sweep/4m89qo6r &
child="$!"

# Set up a handler to pass the SIGUSR1 to the python session launched by the agent
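
For readers skimming the diff: launch.sh backgrounds the agent under `srun` and forwards Slurm's early-warning signal to it. A minimal sketch of that pattern, assuming the standard bash trap idiom (the wait loop is an assumption, not a verbatim copy of the script):

```bash
#!/bin/bash
# Run the wandb agent in the background via srun so the Python process
# it launches can receive SIGUSR1 (see the NOTE in launch.sh).
srun wandb agent --count=1 f-dangel-team/example-preemptable-sweep/4m89qo6r &
child="$!"

# Forward SIGUSR1, sent by Slurm 120 seconds before the time limit
# (--signal=B:SIGUSR1@120), to the backgrounded job step.
trap 'kill -SIGUSR1 "$child"' SIGUSR1

# A trapped signal interrupts `wait`, so keep waiting until the child exits.
while kill -0 "$child" 2>/dev/null; do
    wait "$child"
done
```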
9 changes: 4 additions & 5 deletions example/sweep.yaml
@@ -1,5 +1,4 @@
-entity: f-dangel-team
-project: quickstart
+project: example-preemptable-sweep
name: SGD
program: train.py
command:
@@ -12,9 +11,9 @@ metric:
name: loss
method: random
parameters:
-lr:
+lr_max:
distribution: log_uniform_values
min: 1e-3
max: 1e-1
-max_epochs:
-value: 150
+epochs:
+value: 20
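
For context, a sweep defined by this file is registered with the `wandb sweep` CLI, which prints the agent path consumed by launch.sh; a sketch of the workflow:

```bash
# Register the sweep; prints an agent path of the form <entity>/<project>/<sweep_id>,
# e.g. f-dangel-team/example-preemptable-sweep/4m89qo6r as hard-coded in launch.sh.
wandb sweep sweep.yaml

# Submit the Slurm job array; each array task runs one agent (see launch.sh).
sbatch launch.sh
```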
11 changes: 6 additions & 5 deletions example/train.py
@@ -26,10 +26,10 @@ def get_parser():
r"""Create argument parser."""
parser = ArgumentParser("Train a simple CNN on MNIST using SGD.")
parser.add_argument(
"--lr", type=float, default=0.01, help="Learning rate. Default: %(default)s"
"--lr_max", type=float, default=0.01, help="Learning rate. Default: %(default)s"
)
parser.add_argument(
"--epochs", type=int, default=10, help="Number of epochs. Default: %(default)s"
"--epochs", type=int, default=20, help="Number of epochs. Default: %(default)s"
)
parser.add_argument(
"--batch_size", type=int, default=256, help="Batch size. Default: %(default)s"
@@ -62,8 +62,8 @@ def main(args):
Linear(50, 10),
).to(DEV)
loss_func = CrossEntropyLoss().to(DEV)
print(f"Using SGD with learning rate {args.lr}.")
optimizer = SGD(model.parameters(), lr=args.lr)
print(f"Using SGD with learning rate {args.lr_max}.")
optimizer = SGD(model.parameters(), lr=args.lr_max)
lr_scheduler = CosineAnnealingLR(optimizer, T_max=args.epochs)
scaler = GradScaler()

@@ -82,7 +82,7 @@ def main(args):
# NOTE: If existing, load model, optimizer, and learning rate scheduler state from
# latest checkpoint, set random number generator states, and recover the epoch to
# start training from. Does nothing if there was no checkpoint.
-start_epoch = checkpointer.load_latest_checkpoint()
+start_epoch, _ = checkpointer.load_latest_checkpoint()

# training
for epoch in range(start_epoch, args.epochs):
@@ -101,6 +101,7 @@
"loss": loss.item(),
"lr": optimizer.param_groups[0]["lr"],
"loss_scale": scaler.get_scale(),
"epoch": epoch,
"resumes": checkpointer.num_resumes,
}
)
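
Taken together, the train.py changes implement the following resume pattern. A condensed sketch using only calls visible in the diff; `run_one_epoch` is a hypothetical stand-in for the actual training code:

```python
import wandb

def training_loop(checkpointer, epochs, run_one_epoch):
    # load_latest_checkpoint now always returns a (start_epoch, extra_info)
    # tuple, with (0, {}) when no checkpoint exists.
    start_epoch, _ = checkpointer.load_latest_checkpoint()

    for epoch in range(start_epoch, epochs):
        loss = run_one_epoch()
        # Log the epoch alongside the metrics so resumed runs line up in wandb.
        wandb.log({"loss": loss, "epoch": epoch, "resumes": checkpointer.num_resumes})
        # Save a checkpoint; if SIGUSR1/SIGTERM was received, this call also
        # requeues the Slurm job.
        checkpointer.step()
```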
2 changes: 2 additions & 0 deletions mkdocs.yml
@@ -8,6 +8,7 @@ watch:
nav:
- Getting Started: index.md
- API Documentation: api.md
+- Walkthrough Example: walkthrough.md
- Developer Notes: develop.md
theme:
name: material
@@ -24,6 +25,7 @@ markdown_extensions:
- pymdownx.inlinehilite # code highlighting
- pymdownx.snippets # code highlighting
- pymdownx.superfences # code highlighting
+- pymdownx.blocks.details # fold-able content
- footnotes
plugins:
- mkdocstrings:
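
The newly enabled `pymdownx.blocks.details` extension renders fold-able content; its markdown syntax looks roughly like the following (an approximation from memory; consult the pymdown-extensions docs for the authoritative form):

```markdown
/// details | Show the full training log
Content placed here is collapsed behind the summary line above.
///
```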
25 changes: 14 additions & 11 deletions wandb_preempt/checkpointer.py
@@ -25,10 +25,10 @@ class Checkpointer:

How to use this class:

-- Create an instance in your training loop, `checkpointer = Checkpointer(...)`.
+- Create an instance of this class `checkpointer = Checkpointer(...)`.
 - At the end of each epoch, call `checkpointer.step()` to save a checkpoint.
-If the job received the `SIGUSR1` signal, the checkpointer will requeue the at
-the end of its checkpointing step.
+If the job received the `SIGUSR1` or `SIGTERM` signal, the checkpointer will
+requeue the Slurm job at the end of its checkpointing step.
"""

def __init__(
@@ -113,7 +113,8 @@ def mark_preempted(self, sig: int, frame: Optional[FrameType]):
frame: The current stack frame.
"""
self.maybe_print(
f"Got signal {sig}. Marking as pre-empted. This will be the last epoch."
f"Received signal {sig}. Marking as pre-empted and will halt and requeue"
" the job at next call of checkpointer.step()."
)
self.marked_preempted = True

@@ -194,13 +195,14 @@ def load_latest_checkpoint(
**kwargs: Additional keyword arguments to pass to the `torch.load` function.

Returns:
-The epoch number at which training should resume, and the extra information
-that was passed by the user as a dictionary to the :meth:`step` function.
+epoch: The epoch number at which training should resume.
+extra_info: Extra information that was passed by the user to the `step`
+function.
"""
loadpath = self.latest_checkpoint()
if loadpath is None:
self.maybe_print("No checkpoint found. Starting from scratch.")
-return 0
+return 0, {}

self.maybe_print(f"Loading checkpoint {loadpath}.")

@@ -216,7 +218,7 @@
self.maybe_print("Loading gradient scaler.")
self.scaler.load_state_dict(data["scaler"])

-self.step_count = data["checkpoint_step"]
+self.step_count = data["checkpoint_step"] + 1
self.num_resumes = data["resumes"] + 1

# restore random number generator states for all devices
@@ -227,7 +229,7 @@
else:
set_rng_state(rng_state)

-return self.step_count + 1, data["extra_info"]
+return self.step_count, data["extra_info"]

def remove_checkpoints(self, keep_latest: bool = False):
"""Remove checkpoints.
@@ -317,7 +319,7 @@ def maybe_requeue_slurm_job(self):
array_id = getenv("SLURM_ARRAY_JOB_ID")
task_id = getenv("SLURM_ARRAY_TASK_ID")

-uses_array = array_id is None and task_id is None
+uses_array = array_id is not None and task_id is not None
requeue_id = f"{array_id}_{task_id}" if uses_array else job_id

cmd = ["scontrol", "requeue", requeue_id]
@@ -342,7 +344,8 @@ def step(self, extra_info: Optional[Dict] = None):
Args:
extra_info: Additional information to save in the checkpoint. This
dictionary is returned when loading the latest checkpoint with
-:meth:`load_latest_checkpoint`. Default: `None` (empty dictionary).
+`checkpointer.load_latest_checkpoint()`.
+By default, an empty dictionary is saved.
"""
self.save_checkpoint({} if extra_info is None else extra_info)
# Remove stale checkpoints
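
To spell out the array-detection fix above: array tasks must be requeued as `ARRAYID_TASKID`, plain jobs by their job id, and the old code's `is None` checks inverted that test. A standalone sketch of the corrected logic, assuming the standard Slurm environment variables (an illustration, not the library's exact code):

```python
from os import getenv
from subprocess import run

def requeue_current_job() -> None:
    """Requeue the running Slurm job, handling array and non-array jobs."""
    job_id = getenv("SLURM_JOB_ID")
    array_id = getenv("SLURM_ARRAY_JOB_ID")
    task_id = getenv("SLURM_ARRAY_TASK_ID")
    # Array tasks are addressed as ARRAYID_TASKID; plain jobs by their job id.
    uses_array = array_id is not None and task_id is not None
    requeue_id = f"{array_id}_{task_id}" if uses_array else job_id
    run(["scontrol", "requeue", requeue_id], check=True)
```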