Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ADD] Make checkpointer work outside a SLURM session #7

Merged
merged 3 commits into from
Sep 10, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 32 additions & 26 deletions wandb_preempt/checkpointer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Class for handling checkpointing."""

from datetime import datetime
from datetime import date, datetime
from glob import glob
from os import environ, getenv, getpid, makedirs, path, remove, rename
from signal import SIGTERM, SIGUSR1, signal
Expand Down Expand Up @@ -93,10 +93,6 @@ def __init__(
savedir: Directory to store checkpoints in. Default: `'checkpoints'`.
verbose: Whether to print messages about saving and loading checkpoints.
Default: `False`

Raises:
RuntimeError: If the environment variable `SLURM_JOB_ID` is not set.
This indicates we are not running a SLURM task array job.
"""
self.time_created = time()
self.run_id = run_id
Expand All @@ -118,23 +114,32 @@ def __init__(
self.savedir = path.abspath(savedir)
self.maybe_print(f"Creating checkpoint directory: {self.savedir}.")
makedirs(self.savedir, exist_ok=True)
self.savedir_job = path.join(self.savedir, environ["SLURM_JOB_ID"])

# write Python PID to a file so it can be read by the signal handler from the
# sbatch script, because it has to send a kill signal with SIGUSR1 to that PID.
# Detect whether we are running inside a SLURM session
job_id = getenv("SLURM_JOB_ID")
array_id = getenv("SLURM_ARRAY_JOB_ID")
task_id = getenv("SLURM_ARRAY_TASK_ID")
self.maybe_print(f"Job ID: {job_id}, Array ID: {array_id}, Task ID: {task_id}")

if job_id is None:
raise RuntimeError("SLURM_JOB_ID is not set.")
self.maybe_print(
f"SLURM job ID: {job_id}, array ID: {array_id}, task ID: {task_id}"
)
self.uses_slurm = any(var is not None for var in {job_id, array_id, task_id})

# We will create sub-folders in the directory supplied by the user where
# checkpoints are stored. If we are on SLURM, we will use the `SLURM_JOB_ID`
# variable as name, otherwise we will use the formatted day.
self.savedir_job = path.join(
self.savedir,
f"{environ['SLURM_JOB_ID'] if self.uses_slurm else date.today()}",
)

filename = f"{job_id}.pid"
pid = str(getpid())
self.maybe_print(f"Writing PID {pid} to file {filename}.")
with open(filename, "w") as f:
f.write(pid)
# write Python PID to a file so it can be read by the signal handler from the
# sbatch script, because it has to send a kill signal with SIGUSR1 to that PID.
if self.uses_slurm:
filename = f"{job_id}.pid"
pid = str(getpid())
self.maybe_print(f"Writing PID {pid} to file {filename}.")
with open(filename, "w") as f:
f.write(pid)

def mark_preempted(self, sig: int, frame: Optional[FrameType]):
"""Mark the checkpointer as pre-empted.
Expand Down Expand Up @@ -330,18 +335,19 @@ def maybe_print(self, msg: str, verbose: Optional[bool] = None) -> None:
elapsed = time() - self.time_created
print(f"[{elapsed:.1f} s | {datetime.now()}] {msg}")

def requeue_slurm_job(self):
"""Requeue the Slurm job.
def maybe_requeue_slurm_job(self):
"""Requeue the SLURM job if we are running in a SLURM session."""
if not self.uses_slurm:
return

Raises:
RuntimeError: If the job is not a Slurm job.
"""
job_id = getenv("SLURM_JOB_ID")
array_id = getenv("SLURM_ARRAY_JOB_ID")
task_id = getenv("SLURM_ARRAY_TASK_ID")

if job_id is None:
raise RuntimeError("Not a SLURM job. Variable SLURM_JOB_ID not set.")
uses_array = array_id is None and task_id is None
requeue_id = f"{array_id}_{task_id}" if uses_array else job_id

cmd = ["scontrol", "requeue", job_id]
cmd = ["scontrol", "requeue", requeue_id]
self.maybe_print(f"Requeuing SLURM job with `{' '.join(cmd)}`.")
run(cmd, check=True)

Expand Down Expand Up @@ -373,7 +379,7 @@ def step(self, extra_info: Optional[Dict] = None):
if self.marked_preempted:
self.maybe_print("Run was marked as pre-empted via signal.")
self.preempt_wandb_run()
self.requeue_slurm_job()
self.maybe_requeue_slurm_job()
self.maybe_print("Exiting with error code 1.")
exit(1)
# Increase the number of steps taken
Expand Down
Loading