diff --git a/wandb_preempt/checkpointer.py b/wandb_preempt/checkpointer.py index 4df26ac..2e4e4e0 100644 --- a/wandb_preempt/checkpointer.py +++ b/wandb_preempt/checkpointer.py @@ -1,6 +1,6 @@ """Class for handling checkpointing.""" -from datetime import datetime +from datetime import date, datetime from glob import glob from os import environ, getenv, getpid, makedirs, path, remove, rename from signal import SIGTERM, SIGUSR1, signal @@ -93,10 +93,6 @@ def __init__( savedir: Directory to store checkpoints in. Default: `'checkpoints'`. verbose: Whether to print messages about saving and loading checkpoints. Default: `False` - - Raises: - RuntimeError: If the environment variable `SLURM_JOB_ID` is not set. - This indicates we are not running a SLURM task array job. """ self.time_created = time() self.run_id = run_id @@ -118,23 +114,32 @@ def __init__( self.savedir = path.abspath(savedir) self.maybe_print(f"Creating checkpoint directory: {self.savedir}.") makedirs(self.savedir, exist_ok=True) - self.savedir_job = path.join(self.savedir, environ["SLURM_JOB_ID"]) - # write Python PID to a file so it can be read by the signal handler from the - # sbatch script, because it has to send a kill signal with SIGUSR1 to that PID. + # Detect whether we are running inside a SLURM session job_id = getenv("SLURM_JOB_ID") array_id = getenv("SLURM_ARRAY_JOB_ID") task_id = getenv("SLURM_ARRAY_TASK_ID") - self.maybe_print(f"Job ID: {job_id}, Array ID: {array_id}, Task ID: {task_id}") - - if job_id is None: - raise RuntimeError("SLURM_JOB_ID is not set.") + self.maybe_print( + f"SLURM job ID: {job_id}, array ID: {array_id}, task ID: {task_id}" + ) + self.uses_slurm = any(var is not None for var in {job_id, array_id, task_id}) + + # We will create sub-folders in the directory supplied by the user where + # checkpoints are stored. If we are on SLURM, we will use the `SLURM_JOB_ID` + # variable as name, otherwise we will use the formatted day. + self.savedir_job = path.join( + self.savedir, + f"{environ['SLURM_JOB_ID'] if self.uses_slurm else date.today()}", + ) - filename = f"{job_id}.pid" - pid = str(getpid()) - self.maybe_print(f"Writing PID {pid} to file {filename}.") - with open(filename, "w") as f: - f.write(pid) + # write Python PID to a file so it can be read by the signal handler from the + # sbatch script, because it has to send a kill signal with SIGUSR1 to that PID. + if self.uses_slurm: + filename = f"{job_id}.pid" + pid = str(getpid()) + self.maybe_print(f"Writing PID {pid} to file {filename}.") + with open(filename, "w") as f: + f.write(pid) def mark_preempted(self, sig: int, frame: Optional[FrameType]): """Mark the checkpointer as pre-empted. @@ -330,18 +335,19 @@ def maybe_print(self, msg: str, verbose: Optional[bool] = None) -> None: elapsed = time() - self.time_created print(f"[{elapsed:.1f} s | {datetime.now()}] {msg}") - def requeue_slurm_job(self): - """Requeue the Slurm job. + def maybe_requeue_slurm_job(self): + """Requeue the SLURM job if we are running in a SLURM session.""" + if not self.uses_slurm: + return - Raises: - RuntimeError: If the job is not a Slurm job. - """ job_id = getenv("SLURM_JOB_ID") + array_id = getenv("SLURM_ARRAY_JOB_ID") + task_id = getenv("SLURM_ARRAY_TASK_ID") - if job_id is None: - raise RuntimeError("Not a SLURM job. Variable SLURM_JOB_ID not set.") + uses_array = array_id is None and task_id is None + requeue_id = f"{array_id}_{task_id}" if uses_array else job_id - cmd = ["scontrol", "requeue", job_id] + cmd = ["scontrol", "requeue", requeue_id] self.maybe_print(f"Requeuing SLURM job with `{' '.join(cmd)}`.") run(cmd, check=True) @@ -373,7 +379,7 @@ def step(self, extra_info: Optional[Dict] = None): if self.marked_preempted: self.maybe_print("Run was marked as pre-empted via signal.") self.preempt_wandb_run() - self.requeue_slurm_job() + self.maybe_requeue_slurm_job() self.maybe_print("Exiting with error code 1.") exit(1) # Increase the number of steps taken