Commit: timeout
epwalsh committed Jul 20, 2023
1 parent 6eda824 · commit b4fdc6f

Showing 2 changed files with 14 additions and 5 deletions.
olmo/train.py (7 changes: 3 additions & 4 deletions)

@@ -39,6 +39,7 @@
     move_to_device,
     peak_gpu_memory,
     syncronize_flag,
+    wait_on,
 )

 __all__ = ["SpeedMonitor", "LRMonitor", "Trainer"]
@@ -254,8 +255,7 @@ def save_sharded_checkpoint(self) -> Path:
         # replacing the temp directory with the final directory from rank 0 might not be immediately
         # realized in the file systems of the other ranks.
         # So we wait here across all ranks until that final checkpoint directory is visible.
-        while not checkpoint_dir.exists():
-            time.sleep(0.5)
+        wait_on(lambda: checkpoint_dir.exists(), "Waiting for checkpoint directory", timeout=10.0)

         # Remove old checkpoints.
         if self.cfg.save_num_checkpoints_to_keep > 0:
@@ -401,8 +401,7 @@ def save_unsharded_checkpoint(self) -> Path:
         # replacing the temp directory with the final directory from rank 0 might not be immediately
         # realized in the file systems of the other ranks.
         # So we wait here across all ranks until that final checkpoint directory is visible.
-        while not checkpoint_dir.exists():
-            time.sleep(0.5)
+        wait_on(lambda: checkpoint_dir.exists(), "Waiting for checkpoint directory", timeout=10.0)

         # Remove old checkpoints.
         if self.cfg.save_num_unsharded_checkpoints_to_keep > 0:
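The comment in both hunks explains why the trainer polls at all: rank 0 replaces the temp checkpoint directory with the final one, and on a shared filesystem the other ranks may not see the rename right away. A minimal sketch of that pattern, outside this diff (the finalize_checkpoint helper, the barrier placement, and the directory names are illustrative assumptions, not code from this commit):

import pathlib

import torch.distributed as dist

from olmo.util import wait_on


def finalize_checkpoint(tmp_dir: pathlib.Path, checkpoint_dir: pathlib.Path) -> None:
    # Rank 0 publishes the checkpoint by atomically renaming the temp directory.
    if dist.get_rank() == 0:
        tmp_dir.replace(checkpoint_dir)
    dist.barrier()
    # On a shared network filesystem the rename may not be visible to every rank yet,
    # so each rank polls for it, but now gives up after 10 seconds instead of looping forever.
    wait_on(lambda: checkpoint_dir.exists(), "Waiting for checkpoint directory", timeout=10.0)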
olmo/util.py (12 changes: 11 additions & 1 deletion)

@@ -2,9 +2,10 @@
 import os
 import socket
 import sys
+import time
 import warnings
 from datetime import datetime
-from typing import Any, Dict, Optional, TypeVar, Union
+from typing import Any, Callable, Dict, Optional, TypeVar, Union

 import rich
 import torch
@@ -339,3 +340,12 @@ def syncronize_flag(flag: bool, device: torch.device) -> bool:
         return flag_tensor.item()  # type: ignore
     else:
         return flag
+
+
+def wait_on(condition: Callable[[], bool], description: str, timeout: float = 10.0):
+    """Wait on the condition function to return True."""
+    start_time = time.monotonic()
+    while not condition():
+        time.sleep(0.5)
+        if time.monotonic() - start_time > timeout:
+            raise TimeoutError(f"{description} timed out")
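As a standalone illustration of the new helper's failure mode (a hypothetical script, not part of this commit): wait_on polls every 0.5 s and raises TimeoutError with the given description once the timeout elapses.

from pathlib import Path

from olmo.util import wait_on

marker = Path("/tmp/run-finished.marker")  # hypothetical path another process is expected to create
try:
    wait_on(lambda: marker.exists(), "Waiting for run-finished marker", timeout=5.0)
except TimeoutError as err:
    # Raised as "Waiting for run-finished marker timed out" if the marker never appears.
    print(f"Gave up: {err}")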
