Skip to content

Commit

Permalink
Merge pull request #7 from f-dangel/outside-slurm
Browse files Browse the repository at this point in the history
[ADD] Make checkpointer work outside a SLURM session
  • Loading branch information
scottclowe authored Sep 10, 2024
2 parents 15a8452 + 2184905 commit 6bf7881
Showing 1 changed file with 32 additions and 26 deletions.
58 changes: 32 additions & 26 deletions wandb_preempt/checkpointer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Class for handling checkpointing."""

from datetime import datetime
from datetime import date, datetime
from glob import glob
from os import environ, getenv, getpid, makedirs, path, remove, rename
from signal import SIGTERM, SIGUSR1, signal
Expand Down Expand Up @@ -93,10 +93,6 @@ def __init__(
savedir: Directory to store checkpoints in. Default: `'checkpoints'`.
verbose: Whether to print messages about saving and loading checkpoints.
Default: `False`
Raises:
RuntimeError: If the environment variable `SLURM_JOB_ID` is not set.
This indicates we are not running a SLURM task array job.
"""
self.time_created = time()
self.run_id = run_id
Expand All @@ -118,23 +114,32 @@ def __init__(
self.savedir = path.abspath(savedir)
self.maybe_print(f"Creating checkpoint directory: {self.savedir}.")
makedirs(self.savedir, exist_ok=True)
self.savedir_job = path.join(self.savedir, environ["SLURM_JOB_ID"])

# write Python PID to a file so it can be read by the signal handler from the
# sbatch script, because it has to send a kill signal with SIGUSR1 to that PID.
# Detect whether we are running inside a SLURM session
job_id = getenv("SLURM_JOB_ID")
array_id = getenv("SLURM_ARRAY_JOB_ID")
task_id = getenv("SLURM_ARRAY_TASK_ID")
self.maybe_print(f"Job ID: {job_id}, Array ID: {array_id}, Task ID: {task_id}")

if job_id is None:
raise RuntimeError("SLURM_JOB_ID is not set.")
self.maybe_print(
f"SLURM job ID: {job_id}, array ID: {array_id}, task ID: {task_id}"
)
self.uses_slurm = any(var is not None for var in {job_id, array_id, task_id})

# We will create sub-folders in the directory supplied by the user where
# checkpoints are stored. If we are on SLURM, we will use the `SLURM_JOB_ID`
# variable as name, otherwise we will use the formatted day.
self.savedir_job = path.join(
self.savedir,
f"{environ['SLURM_JOB_ID'] if self.uses_slurm else date.today()}",
)

filename = f"{job_id}.pid"
pid = str(getpid())
self.maybe_print(f"Writing PID {pid} to file {filename}.")
with open(filename, "w") as f:
f.write(pid)
# write Python PID to a file so it can be read by the signal handler from the
# sbatch script, because it has to send a kill signal with SIGUSR1 to that PID.
if self.uses_slurm:
filename = f"{job_id}.pid"
pid = str(getpid())
self.maybe_print(f"Writing PID {pid} to file {filename}.")
with open(filename, "w") as f:
f.write(pid)

def mark_preempted(self, sig: int, frame: Optional[FrameType]):
"""Mark the checkpointer as pre-empted.
Expand Down Expand Up @@ -330,18 +335,19 @@ def maybe_print(self, msg: str, verbose: Optional[bool] = None) -> None:
elapsed = time() - self.time_created
print(f"[{elapsed:.1f} s | {datetime.now()}] {msg}")

def requeue_slurm_job(self):
"""Requeue the Slurm job.
def maybe_requeue_slurm_job(self):
"""Requeue the SLURM job if we are running in a SLURM session."""
if not self.uses_slurm:
return

Raises:
RuntimeError: If the job is not a Slurm job.
"""
job_id = getenv("SLURM_JOB_ID")
array_id = getenv("SLURM_ARRAY_JOB_ID")
task_id = getenv("SLURM_ARRAY_TASK_ID")

if job_id is None:
raise RuntimeError("Not a SLURM job. Variable SLURM_JOB_ID not set.")
uses_array = array_id is None and task_id is None
requeue_id = f"{array_id}_{task_id}" if uses_array else job_id

cmd = ["scontrol", "requeue", job_id]
cmd = ["scontrol", "requeue", requeue_id]
self.maybe_print(f"Requeuing SLURM job with `{' '.join(cmd)}`.")
run(cmd, check=True)

Expand Down Expand Up @@ -373,7 +379,7 @@ def step(self, extra_info: Optional[Dict] = None):
if self.marked_preempted:
self.maybe_print("Run was marked as pre-empted via signal.")
self.preempt_wandb_run()
self.requeue_slurm_job()
self.maybe_requeue_slurm_job()
self.maybe_print("Exiting with error code 1.")
exit(1)
# Increase the number of steps taken
Expand Down

0 comments on commit 6bf7881

Please sign in to comment.