diff --git a/example/launch.sh b/example/launch.sh
index d2bd5d6..b5fcaaa 100644
--- a/example/launch.sh
+++ b/example/launch.sh
@@ -8,8 +8,8 @@
 #SBATCH --qos=m5
 #SBATCH --open-mode=append
 #SBATCH --time=00:04:00
-#SBATCH --array=0-9
-#SBATCH --signal=B:SIGUSR1@120 # Send signal SIGUSR1 120 seconds before the job hits the time limit
+#SBATCH --array=0
+#SBATCH --signal=B:SIGUSR1@150 # Send signal SIGUSR1 150 seconds before the job hits the time limit

 echo "Job $SLURM_JOB_NAME ($SLURM_JOB_ID) begins on $(hostname), submitted from $SLURM_SUBMIT_HOST ($SLURM_CLUSTER_NAME)"
 echo ""
@@ -20,7 +20,7 @@ if [ "$SLURM_ARRAY_TASK_COUNT" != "" ]; then
 fi

 # NOTE that we need to use srun here, otherwise the Python process won't receive the SIGUSR1 signal
-srun wandb agent --count=1 f-dangel-team/example-preemptable-sweep/4m89qo6r &
+srun --unbuffered wandb agent --count=1 f-dangel-team/example-preemptable-sweep/4m89qo6r &
 child="$!"

 # Set up a handler to pass the SIGUSR1 to the python session launched by the agent
diff --git a/example/train.py b/example/train.py
index 6dfda17..c5c8de1 100755
--- a/example/train.py
+++ b/example/train.py
@@ -5,6 +5,7 @@
 """

 from argparse import ArgumentParser
+from os import environ

 import wandb
 from torch import autocast, bfloat16, cuda, device, manual_seed
@@ -45,9 +46,6 @@ def main(args):
     manual_seed(0)  # make deterministic
     DEV = device("cuda" if cuda.is_available() else "cpu")

-    # NOTE: Allow runs to resume by passing 'allow' to wandb
-    run = wandb.init(resume="allow")
-
     # Set up the data, neural net, loss function, and optimizer
     train_dataset = MNIST("./data", train=True, download=True, transform=ToTensor())
     train_loader = DataLoader(
@@ -69,8 +67,9 @@ def main(args):

     # NOTE: Set up a check-pointer which will load and save checkpoints.
     # Pass the run ID to obtain unique file names for the checkpoints.
+    run_id = environ["WANDB_RUN_ID"]
     checkpointer = Checkpointer(
-        run.id,
+        run_id,
         model,
         optimizer,
         lr_scheduler=lr_scheduler,
@@ -82,10 +81,27 @@ def main(args):
     # NOTE: If existing, load model, optimizer, and learning rate scheduler state from
     # latest checkpoint, set random number generator states. If there was no checkpoint
     # to load, it does nothing and returns `None` for the step count.
-    checkpoint_index, _ = checkpointer.load_latest_checkpoint()
+    checkpoint_index, extra_info = checkpointer.load_latest_checkpoint()

     # Select the remaining epochs to train
     start_epoch = 0 if checkpoint_index is None else checkpoint_index + 1
+    # NOTE forking must be enabled by the wandb team for your project.
+    wandb_resume_step = extra_info.get("wandb_step", None)
+    resume_from = (
+        None if wandb_resume_step is None else f"{run_id}?_step={wandb_resume_step}"
+    )
+    resume = "allow" if resume_from is None else None
+    print("resume_from:", resume_from)
+    print("resume:", resume)
+    wandb.init(resume=resume, resume_from=resume_from)
+
+    # NOTE: Allow runs to resume by passing 'allow' to wandb
+    # wandb.init(resume="allow")
+    # print("Wandb step before manually setting it:", wandb.run.step)
+    # NOTE: Currently getting an error from setattr here
+    # wandb.run.step = extra_info.get("wandb_step", 0)
+    # print("Wandb step after manually setting it:", wandb.run.step)
+
     # training
     for epoch in range(start_epoch, args.epochs):
         model.train()
@@ -122,7 +138,7 @@ def main(args):
             # running out, it will now also take care of pre-empting the wandb job
             # and requeuing the SLURM job, killing the current python training script
             # to resume with the requeued job.
- checkpointer.step() + checkpointer.step(extra_info={"wandb_step": wandb.run.step}) wandb.finish() # NOTE Remove all created checkpoints once we are done training. If you want to