Skip to content

Commit

Permalink
[DEL] Remove get_resume_value
Browse files Browse the repository at this point in the history
  • Loading branch information
f-dangel committed Sep 9, 2024
1 parent 2184905 commit 1dcc89d
Show file tree
Hide file tree
Showing 3 changed files with 2 additions and 43 deletions.
2 changes: 0 additions & 2 deletions docs/api.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
::: wandb_preempt.get_resume_value

::: wandb_preempt.Checkpointer
options:
members:
Expand Down
4 changes: 2 additions & 2 deletions example/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

from wandb_preempt.checkpointer import Checkpointer, get_resume_value
from wandb_preempt.checkpointer import Checkpointer

LOGGING_INTERVAL = 50 # Num batches between logging to stdout and wandb
VERBOSE = True # Enable verbose output
Expand Down Expand Up @@ -46,7 +46,7 @@ def main(args):
DEV = device("cuda" if cuda.is_available() else "cpu")

# NOTE: Figure out the `resume` value and pass it to wandb
run = wandb.init(resume=get_resume_value(verbose=VERBOSE))
run = wandb.init(resume="allow")

# Set up the data, neural net, loss function, and optimizer
train_dataset = MNIST("./data", train=True, download=True, transform=ToTensor())
Expand Down
39 changes: 0 additions & 39 deletions wandb_preempt/checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,45 +16,6 @@
from torch.nn import Module
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from wandb import Api


def get_resume_value(verbose: bool = False) -> str:
"""Return the `resume` value a the agent's run.
Args:
verbose: Whether to print information to the command line. Default: `False`.
Returns:
The run's resume value. Either `'must'` or `'allow'`.
Raises:
RuntimeError: If the environment variables `WANDB_ENTITY`, `WANDB_PROJECT`, or
`WANDB_RUN_ID`, which are usually set by a wandb agent, are not set.
"""
if verbose:
print("Environment variables containing 'WANDB'")
for key, value in environ.items():
if "WANDB" in key:
print(f"{key}: {value}")

for var in {"WANDB_ENTITY", "WANDB_PROJECT", "WANDB_RUN_ID"}:
if var not in environ:
raise RuntimeError(f"Environment variable {var!r} was not set.")

entity = environ["WANDB_ENTITY"]
project = environ["WANDB_PROJECT"]
run_id = environ["WANDB_RUN_ID"]

run = Api().run(f"{entity}/{project}/{run_id}")
resume = "must" if run.state == "preempted" else "allow"
if verbose:
print(
f"Agent's run has ID {run.id} and state {run.state}."
+ f" Using resume={resume!r}."
)

return resume


class Checkpointer:
Expand Down

0 comments on commit 1dcc89d

Please sign in to comment.