From e28ffd22d01fbeae24a01390e62fe8931ae123c1 Mon Sep 17 00:00:00 2001
From: Felix Dangel
Date: Sat, 14 Sep 2024 11:15:02 -0400
Subject: [PATCH 1/3] [ADD] Save and restore wandb step

---
 example/train.py | 20 ++++++++++++++------
 1 file changed, 14 insertions(+), 6 deletions(-)

diff --git a/example/train.py b/example/train.py
index 6dfda17..e98e8c5 100755
--- a/example/train.py
+++ b/example/train.py
@@ -5,6 +5,7 @@
 """

 from argparse import ArgumentParser
+from os import environ

 import wandb
 from torch import autocast, bfloat16, cuda, device, manual_seed
@@ -45,9 +46,6 @@ def main(args):
     manual_seed(0)  # make deterministic
     DEV = device("cuda" if cuda.is_available() else "cpu")

-    # NOTE: Allow runs to resume by passing 'allow' to wandb
-    run = wandb.init(resume="allow")
-
     # Set up the data, neural net, loss function, and optimizer
     train_dataset = MNIST("./data", train=True, download=True, transform=ToTensor())
     train_loader = DataLoader(
@@ -69,8 +67,9 @@ def main(args):

     # NOTE: Set up a check-pointer which will load and save checkpoints.
     # Pass the run ID to obtain unique file names for the checkpoints.
+    run_id = environ["WANDB_RUN_ID"]
     checkpointer = Checkpointer(
-        run.id,
+        run_id,
         model,
         optimizer,
         lr_scheduler=lr_scheduler,
@@ -82,10 +81,19 @@ def main(args):
     # NOTE: If existing, load model, optimizer, and learning rate scheduler state from
     # latest checkpoint, set random number generator states. If there was no checkpoint
     # to load, it does nothing and returns `None` for the step count.
-    checkpoint_index, _ = checkpointer.load_latest_checkpoint()
+    checkpoint_index, extra_info = checkpointer.load_latest_checkpoint()

     # Select the remaining epochs to train
     start_epoch = 0 if checkpoint_index is None else checkpoint_index + 1

+    wandb_resume_step = extra_info.get("wandb_step", None)
+    resume_from = (
+        None if wandb_resume_step is None else f"{run_id}?_step={wandb_resume_step}"
+    )
+    print("Resume string:", resume_from)
+
+    # NOTE: Allow runs to resume by passing 'allow' to wandb
+    wandb.init(resume="allow", resume_from=resume_from)
+
     # training
     for epoch in range(start_epoch, args.epochs):
         model.train()
@@ -122,7 +130,7 @@ def main(args):
             # running out, it will now also take care of pre-empting the wandb job
             # and requeuing the SLURM job, killing the current python training script
             # to resume with the requeued job.
-            checkpointer.step()
+            checkpointer.step(extra_info={"wandb_step": wandb.run.step})

     wandb.finish()

     # NOTE Remove all created checkpoints once we are done training.
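
The mechanism this patch introduces is a round trip: every checkpoint stores the current `wandb.run.step` as extra metadata, and on restart the run is rewound to that step via `resume_from="<run_id>?_step=<step>"`, so the logged history lines up again after preemption. The sketch below shows just that round trip without the SLURM and Checkpointer machinery. It is illustrative only: the `checkpoint.pt` path and the toy loss loop are made up, and it assumes a wandb version where the rewind (`resume_from`) feature is available. Note that `resume` and `resume_from` are mutually exclusive, which PATCH 3 below accounts for:

    from os import environ
    from os.path import exists

    import wandb
    from torch import load, save

    CKPT = "checkpoint.pt"  # hypothetical path; the example delegates storage to Checkpointer

    # The wandb agent that launches train.py exposes the run ID via this variable
    run_id = environ["WANDB_RUN_ID"]

    # Restore the last logged wandb step from the checkpoint, if one exists
    wandb_step = load(CKPT)["wandb_step"] if exists(CKPT) else None

    if wandb_step is None:
        wandb.init(resume="allow")  # fresh start, or plain resume
    else:
        # Rewind the run's history to the step stored in the checkpoint
        wandb.init(resume_from=f"{run_id}?_step={wandb_step}")

    for step in range(wandb.run.step, 100):
        wandb.log({"loss": 1.0 / (step + 1)})
        # Persist the current wandb step next to the model/optimizer state
        save({"wandb_step": wandb.run.step}, CKPT)

    wandb.finish()
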
From 05152ea62fa6b3fd4a0d575507a72b04ba448852 Mon Sep 17 00:00:00 2001
From: Felix Dangel
Date: Sat, 14 Sep 2024 11:38:26 -0400
Subject: [PATCH 2/3] Try setting wandb step manually

---
 example/launch.sh |  6 +++---
 example/train.py  | 20 ++++++++++++++------
 2 files changed, 17 insertions(+), 9 deletions(-)

diff --git a/example/launch.sh b/example/launch.sh
index d2bd5d6..b5fcaaa 100644
--- a/example/launch.sh
+++ b/example/launch.sh
@@ -8,8 +8,8 @@
 #SBATCH --qos=m5
 #SBATCH --open-mode=append
 #SBATCH --time=00:04:00
-#SBATCH --array=0-9
-#SBATCH --signal=B:SIGUSR1@120 # Send signal SIGUSR1 120 seconds before the job hits the time limit
+#SBATCH --array=0
+#SBATCH --signal=B:SIGUSR1@150 # Send signal SIGUSR1 150 seconds before the job hits the time limit

 echo "Job $SLURM_JOB_NAME ($SLURM_JOB_ID) begins on $(hostname), submitted from $SLURM_SUBMIT_HOST ($SLURM_CLUSTER_NAME)"
 echo ""
@@ -20,7 +20,7 @@ if [ "$SLURM_ARRAY_TASK_COUNT" != "" ]; then
 fi

 # NOTE that we need to use srun here, otherwise the Python process won't receive the SIGUSR1 signal
-srun wandb agent --count=1 f-dangel-team/example-preemptable-sweep/4m89qo6r &
+srun --unbuffered wandb agent --count=1 f-dangel-team/example-preemptable-sweep/4m89qo6r &
 child="$!"

 # Set up a handler to pass the SIGUSR1 to the python session launched by the agent
diff --git a/example/train.py b/example/train.py
index e98e8c5..4a46b49 100755
--- a/example/train.py
+++ b/example/train.py
@@ -85,14 +85,22 @@ def main(args):
     # Select the remaining epochs to train
     start_epoch = 0 if checkpoint_index is None else checkpoint_index + 1

-    wandb_resume_step = extra_info.get("wandb_step", None)
-    resume_from = (
-        None if wandb_resume_step is None else f"{run_id}?_step={wandb_resume_step}"
-    )
-    print("Resume string:", resume_from)
+    # NOTE: Forking must be enabled by the wandb team for your project.
+    # wandb_resume_step = extra_info.get("wandb_step", None)
+    # resume_from = (
+    #     None if wandb_resume_step is None else f"{run_id}?_step={wandb_resume_step}"
+    # )
+    # resume = "allow" if resume_from is None else None
+    # print("resume_from:", resume_from)
+    # print("resume:", resume)
+    # wandb.init(resume=resume, resume_from=resume_from)

     # NOTE: Allow runs to resume by passing 'allow' to wandb
-    wandb.init(resume="allow", resume_from=resume_from)
+    wandb.init(resume="allow")
+    print("Wandb step before manually setting it:", wandb.run.step)
+    # NOTE: Currently getting an error from setattr here
+    wandb.run.step = extra_info.get("wandb_step", 0)
+    print("Wandb step after manually setting it:", wandb.run.step)

     # training
     for epoch in range(start_epoch, args.epochs):

From b9cbf9e0b8d053d772a1a2f30be82e974fd12da2 Mon Sep 17 00:00:00 2001
From: Felix Dangel
Date: Wed, 25 Sep 2024 09:22:12 -0400
Subject: [PATCH 3/3] Try using `resume_from` argument

---
 example/train.py | 24 ++++++++++++------------
 1 file changed, 12 insertions(+), 12 deletions(-)

diff --git a/example/train.py b/example/train.py
index 4a46b49..c5c8de1 100755
--- a/example/train.py
+++ b/example/train.py
@@ -86,21 +86,21 @@ def main(args):
     start_epoch = 0 if checkpoint_index is None else checkpoint_index + 1

     # NOTE: Forking must be enabled by the wandb team for your project.
-    # wandb_resume_step = extra_info.get("wandb_step", None)
-    # resume_from = (
-    #     None if wandb_resume_step is None else f"{run_id}?_step={wandb_resume_step}"
-    # )
-    # resume = "allow" if resume_from is None else None
-    # print("resume_from:", resume_from)
-    # print("resume:", resume)
-    # wandb.init(resume=resume, resume_from=resume_from)
+    wandb_resume_step = extra_info.get("wandb_step", None)
+    resume_from = (
+        None if wandb_resume_step is None else f"{run_id}?_step={wandb_resume_step}"
+    )
+    resume = "allow" if resume_from is None else None
+    print("resume_from:", resume_from)
+    print("resume:", resume)
+    wandb.init(resume=resume, resume_from=resume_from)

     # NOTE: Allow runs to resume by passing 'allow' to wandb
-    wandb.init(resume="allow")
-    print("Wandb step before manually setting it:", wandb.run.step)
+    # wandb.init(resume="allow")
+    # print("Wandb step before manually setting it:", wandb.run.step)
     # NOTE: Currently getting an error from setattr here
-    wandb.run.step = extra_info.get("wandb_step", 0)
-    print("Wandb step after manually setting it:", wandb.run.step)
+    # wandb.run.step = extra_info.get("wandb_step", 0)
+    # print("Wandb step after manually setting it:", wandb.run.step)

     # training
     for epoch in range(start_epoch, args.epochs):
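
After this patch, the sequence the example settles on is: load the checkpoint, derive the rewind target from its `wandb_step` metadata, and pass either `resume` or `resume_from` (never both) to `wandb.init`. Condensed into a standalone sketch, with `extra_info` standing in for whatever `Checkpointer.load_latest_checkpoint()` returned (the step value 42 is an arbitrary placeholder):

    from os import environ

    import wandb

    run_id = environ["WANDB_RUN_ID"]
    extra_info = {"wandb_step": 42}  # placeholder for the loaded checkpoint metadata

    wandb_resume_step = extra_info.get("wandb_step", None)
    resume_from = (
        None if wandb_resume_step is None else f"{run_id}?_step={wandb_resume_step}"
    )
    # resume and resume_from are mutually exclusive, so exactly one is set
    resume = "allow" if resume_from is None else None
    wandb.init(resume=resume, resume_from=resume_from)

    # Training would continue here; each checkpoint again stashes wandb.run.step
    # so that the next requeued job can rewind to it.
    wandb.finish()

Keep in mind the caveat from the hunk above: rewinding is a gated wandb feature, so it must be enabled for the project before this code path works.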