Commit
Add support for remote checkpoints and train data files (#237)
epwalsh authored Aug 2, 2023
1 parent e350fd3 commit 642d0fa
Showing 5 changed files with 152 additions and 13 deletions.
9 changes: 8 additions & 1 deletion olmo/config.py
@@ -82,8 +82,10 @@ def path_glob(*paths) -> List[str]:

# Chooses the first path in the arguments that exists.
def path_choose(*paths) -> str:
from .util import is_url

for path in paths:
if Path(path).exists():
if is_url(path) or Path(path).exists():
return path
if validate_paths:
raise FileNotFoundError(", ".join(paths))
@@ -528,6 +530,11 @@ class TrainConfig(BaseConfig):
The directory to save checkpoints to.
"""

remote_save_folder: Optional[str] = None
"""
A folder in a cloud bucket to upload saved checkpoints to.
"""

save_interval: int = 1000
"""
How often (in terms of batches) to save training state checkpoints that can be used for restarts.
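These config changes allow data and checkpoint paths to be URLs: path_choose now accepts any path that is a URL without checking the local filesystem, and the new remote_save_folder option names a cloud bucket folder that checkpoints get mirrored to. Below is a minimal, self-contained sketch of the new path_choose behavior; the regex mirrors olmo.util.is_url, the bucket and file names are made up, and the standalone signature (including the validate_paths keyword) is illustrative, since the real path_choose is a nested helper that reads validate_paths from its enclosing scope.

import re
from pathlib import Path
from typing import Optional


def is_url(path) -> bool:
    # Same scheme check as olmo.util.is_url: anything like "gs://", "s3://", "https://".
    return re.match(r"[a-z0-9]+://.*", str(path)) is not None


def path_choose(*paths: str, validate_paths: bool = True) -> Optional[str]:
    # Return the first argument that is either a URL or an existing local path.
    for path in paths:
        if is_url(path) or Path(path).exists():
            return path
    if validate_paths:
        raise FileNotFoundError(", ".join(paths))
    return None


print(path_choose("s3://my-bucket/tokens/part-000.npy", "/tmp/part-000.npy"))
# -> "s3://my-bucket/tokens/part-000.npy" (URLs are accepted without touching the filesystem)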
7 changes: 4 additions & 3 deletions olmo/data/memmap_dataset.py
@@ -1,14 +1,15 @@
from __future__ import annotations

import os
from copy import deepcopy
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from cached_path import cached_path
from torch.utils.data import Dataset

from ..aliases import PathOrStr
from ..util import file_size

__all__ = ["MemMapDataset"]

@@ -70,15 +71,15 @@ def offsets(self) -> List[Tuple[int, int]]:
start_offset = 0
self._mmap_offsets = []
for path in self._memmap_paths:
length = os.stat(path).st_size // (self._item_size * self._chunk_size)
length = file_size(path) // (self._item_size * self._chunk_size)
end_offset = start_offset + length
self._mmap_offsets.append((start_offset, end_offset))
start_offset += length
return self._mmap_offsets

def _read_chunk_from_memmap(self, path: PathOrStr, index: int) -> torch.Tensor:
index_start = index * self._item_size * self._chunk_size
with open(path, "rb") as f:
with open(cached_path(path), "rb") as f:
f.seek(index_start)
buffer = f.read(self._item_size * self._chunk_size)
array = np.frombuffer(buffer, dtype=self.dtype)
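With this change the memory-mapped dataset can be built from remote token files: per-file lengths come from the new file_size helper (which understands gs:// and s3:// URLs) and chunk reads go through cached_path, which returns local paths unchanged and downloads remote files to a local cache on first use. A hedged sketch of that read path follows; the dtype, chunk size, and function name are example choices, not the dataset's actual configuration.

import numpy as np
from cached_path import cached_path  # resolves URLs to cached local copies; local paths pass through

ITEM_SIZE = np.dtype(np.uint16).itemsize  # example token dtype
CHUNK_SIZE = 1024                         # example number of tokens per training instance


def read_chunk(path: str, index: int) -> np.ndarray:
    # Works for local paths and, given credentials, for gs:// / s3:// / https:// URLs.
    local_path = cached_path(path)
    start = index * ITEM_SIZE * CHUNK_SIZE
    with open(local_path, "rb") as f:
        f.seek(start)
        buffer = f.read(ITEM_SIZE * CHUNK_SIZE)
    return np.frombuffer(buffer, dtype=np.uint16)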
42 changes: 34 additions & 8 deletions olmo/train.py
@@ -27,6 +27,7 @@
from torch.utils.data import DataLoader
from torchmetrics import MeanMetric

from .aliases import PathOrStr
from .config import CheckpointType, SpeedMonitorConfig, TrainConfig
from .data import IterableDataset
from .eval import Evaluator
@@ -39,7 +40,9 @@
get_world_size,
move_to_device,
peak_gpu_memory,
resource_path,
syncronize_flag,
upload,
wait_on,
)

@@ -269,6 +272,18 @@ def save_sharded_checkpoint(self) -> Path:

barrier()

# Upload checkpoint to bucket.
if self.cfg.remote_save_folder is not None:
files_to_upload = [f"rank{get_global_rank()}.pt"]
if get_global_rank() == 0:
files_to_upload.append("config.yaml")
for fname in files_to_upload:
source = checkpoint_dir / fname
target = f"{self.cfg.remote_save_folder}/{checkpoint_dir.name}/{fname}"
log.info(f"Uploading {source} to {target}...")
upload(source, target, save_overwrite=self.cfg.save_overwrite)
barrier()

return checkpoint_dir

def remove_sharded_checkpoint(self, idx: int = 0):
@@ -281,7 +296,7 @@ def remove_sharded_checkpoint(self, idx: int = 0):
latest_path.unlink()
barrier()

def restore_sharded_checkpoint(self, load_path: Path):
def restore_sharded_checkpoint(self, load_path: PathOrStr):
# Zero-gradients to avoid gathering them.
self.optim.zero_grad(set_to_none=True)

@@ -312,7 +327,7 @@ def restore_sharded_checkpoint(self, load_path: Path):
# self.optim.load_state_dict(flattened_osd)

# Deserialize state dictionary.
state_dict = torch.load(load_path / f"rank{get_global_rank()}.pt")
state_dict = torch.load(resource_path(load_path, f"rank{get_global_rank()}.pt"))

# Load model and optimizer state.
log.info("Loading model state...")
@@ -411,6 +426,17 @@ def save_unsharded_checkpoint(self) -> Path:
self.remove_unsharded_checkpoint(0)

barrier()

# Upload checkpoint to bucket.
if self.cfg.remote_save_folder is not None:
if get_global_rank() == 0:
for fname in ["config.yaml", "model.pt", "optim.pt", "other.pt"]:
source = checkpoint_dir / fname
target = f"{self.cfg.remote_save_folder}/{checkpoint_dir.name}/{fname}"
log.info(f"Uploading {source} to {target}...")
upload(source, target, save_overwrite=self.cfg.save_overwrite)
barrier()

return checkpoint_dir

def remove_unsharded_checkpoint(self, idx: int = 0):
@@ -423,7 +449,7 @@ def remove_unsharded_checkpoint(self, idx: int = 0):
latest_path.unlink()
barrier()

def restore_unsharded_checkpoint(self, load_path: Path):
def restore_unsharded_checkpoint(self, load_path: PathOrStr):
# Zero-gradients to avoid gathering them.
self.optim.zero_grad(set_to_none=True)

@@ -435,11 +461,11 @@ def restore_unsharded_checkpoint(self, load_path: Path):
):
# Load model state.
log.info("Loading model state...")
self.fsdp_model.load_state_dict(torch.load(load_path / "model.pt"))
self.fsdp_model.load_state_dict(torch.load(resource_path(load_path, "model.pt")))

# Load optimizer state.
log.info("Loading optimizer state...")
optim_state_dict = torch.load(load_path / "optim.pt")
optim_state_dict = torch.load(resource_path(load_path, "optim.pt"))
# NOTE: careful, the order of these arguments has changed since the 2.0 release.
if version.parse(torch.__version__) < version.parse("2.1.0"):
# flattened_osd = FSDP.optim_state_dict_to_load(optim_state["optim"], self.fsdp_model, self.optim) # type: ignore
@@ -452,7 +478,7 @@ def restore_unsharded_checkpoint(self, load_path: Path):
del flattened_osd

# Load other state.
other_state_dict = torch.load(load_path / "other.pt")
other_state_dict = torch.load(resource_path(load_path, "other.pt"))
self.load_non_tensor_state_dict(other_state_dict)

barrier()
@@ -465,9 +491,9 @@ def save_checkpoint(self, checkpoint_type: CheckpointType = CheckpointType.shard
else:
raise NotImplementedError(checkpoint_type)

def restore_checkpoint(self, load_path: Path, checkpoint_type: Optional[CheckpointType] = None):
def restore_checkpoint(self, load_path: PathOrStr, checkpoint_type: Optional[CheckpointType] = None):
if checkpoint_type == CheckpointType.unsharded or (
checkpoint_type is None and load_path.name.endswith("-unsharded")
checkpoint_type is None and str(load_path).endswith("-unsharded")
):
self.restore_unsharded_checkpoint(load_path)
elif checkpoint_type == CheckpointType.sharded or checkpoint_type is None:
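The trainer changes do two things: after a checkpoint is written locally, each rank uploads its own shard (rank 0 also uploads config.yaml, and the unsharded variant uploads model.pt, optim.pt, and other.pt) to {remote_save_folder}/{checkpoint_dir.name}/, and the restore methods now take PathOrStr so a load path can be a URL resolved through resource_path. Here is a small sketch of how the sharded upload targets are derived; the folder names and rank are example values, not from a real run.

from pathlib import Path
from typing import List, Tuple


def sharded_upload_targets(checkpoint_dir: Path, remote_save_folder: str, rank: int) -> List[Tuple[Path, str]]:
    # Every rank uploads its own shard; rank 0 additionally uploads the run config.
    fnames = [f"rank{rank}.pt"] + (["config.yaml"] if rank == 0 else [])
    return [(checkpoint_dir / f, f"{remote_save_folder}/{checkpoint_dir.name}/{f}") for f in fnames]


for source, target in sharded_upload_targets(Path("/results/step1000"), "s3://my-bucket/run1", rank=0):
    print(f"would upload {source} -> {target}")
# would upload /results/step1000/rank0.pt -> s3://my-bucket/run1/step1000/rank0.pt
# would upload /results/step1000/config.yaml -> s3://my-bucket/run1/step1000/config.yaml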
105 changes: 105 additions & 0 deletions olmo/util.py
@@ -1,10 +1,12 @@
import logging
import os
import re
import socket
import sys
import time
import warnings
from datetime import datetime
from pathlib import Path
from typing import Any, Callable, Dict, Optional, TypeVar, Union

import rich
@@ -15,6 +17,7 @@
from rich.text import Text
from rich.traceback import Traceback

from .aliases import PathOrStr
from .config import LogFilterType
from .exceptions import OlmoCliError, OlmoError

@@ -349,3 +352,105 @@ def wait_on(condition: Callable[[], bool], description: str, timeout: float = 10
time.sleep(0.5)
if time.monotonic() - start_time > timeout:
raise TimeoutError(f"{description} timed out")


def is_url(path: PathOrStr) -> bool:
return re.match(r"[a-z0-9]+://.*", str(path)) is not None


def resource_path(folder: PathOrStr, fname: str) -> PathOrStr:
if is_url(folder):
from cached_path import cached_path

return cached_path(f"{folder}/{fname}")
else:
return Path(folder) / fname


def file_size(path: PathOrStr) -> int:
"""
Get the size of a local or remote file in bytes.
"""
if is_url(path):
from urllib.parse import urlparse

parsed = urlparse(str(path))
if parsed.scheme == "gs":
return _gcs_file_size(parsed.netloc, parsed.path.strip("/"))
elif parsed.scheme == "s3":
return _s3_file_size(parsed.netloc, parsed.path.strip("/"))
elif parsed.scheme == "file":
return file_size(str(path).replace("file://", "", 1))
else:
raise NotImplementedError(f"file size not implemented for '{parsed.scheme}' files")
else:
return os.stat(path).st_size


def upload(source: PathOrStr, target: str, save_overwrite: bool = False):
"""Upload source file to a target location on GCS or S3."""
from urllib.parse import urlparse

source = Path(source)
assert source.is_file()
parsed = urlparse(target)
if parsed.scheme == "gs":
_gcs_upload(source, parsed.netloc, parsed.path.strip("/"), save_overwrite=save_overwrite)
elif parsed.scheme == "s3":
_s3_upload(source, parsed.netloc, parsed.path.strip("/"), save_overwrite=save_overwrite)
else:
raise NotImplementedError(f"Upload not implemented for '{parsed.scheme}' scheme")


def _gcs_file_size(bucket_name: str, key: str) -> int:
from google.api_core.exceptions import NotFound
from google.cloud import storage as gcs

storage_client = gcs.Client()
bucket = storage_client.bucket(bucket_name)
blob = bucket.blob(key)
try:
blob.reload()
except NotFound:
raise FileNotFoundError("gs://{bucket_name}/{key}")
assert blob.size is not None
return blob.size


def _gcs_upload(source: Path, bucket_name: str, key: str, save_overwrite: bool = False):
from google.cloud import storage as gcs

storage_client = gcs.Client()
bucket = storage_client.bucket(bucket_name)
blob = bucket.blob(key)
if not save_overwrite and blob.exists():
raise FileExistsError(f"gs://{bucket_name}/{key} already exists. Use save_overwrite to overwrite it.")
blob.upload_from_filename(source)


def _s3_upload(source: Path, bucket_name: str, key: str, save_overwrite: bool = False):
import boto3
from botocore.exceptions import ClientError

s3_client = boto3.client("s3")
if not save_overwrite:
try:
s3_client.head_object(Bucket=bucket_name, Key=key)
raise FileExistsError(f"gs://{bucket_name}/{key} already exists. Use save_overwrite to overwrite it.")
except ClientError as e:
if int(e.response["Error"]["Code"]) != 404:
raise
s3_client.upload_file(source, bucket_name, key)


def _s3_file_size(bucket_name: str, key: str) -> int:
import boto3
from botocore.exceptions import ClientError

s3_client = boto3.client("s3")
try:
return s3_client.head_object(Bucket=bucket_name, Key=key)["ContentLength"]
except ClientError as e:
if int(e.response["Error"]["Code"]) != 404:
raise
raise FileNotFoundError("s3://{bucket_name}/{key}")
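The new helpers keep the cloud SDKs optional by importing google-cloud-storage and boto3 lazily inside the functions that need them. A quick usage sketch follows, limited to local paths so it runs without credentials and assuming the olmo package from this repo is installed; the commented-out line shows the assumed behavior for a remote path.

from pathlib import Path

from olmo.util import file_size, is_url, resource_path

tmp = Path("/tmp/olmo-util-demo")
tmp.mkdir(exist_ok=True)
(tmp / "config.yaml").write_text("dry_run: false\n")

print(is_url("gs://my-bucket/run1/step1000"))  # True, matches the scheme regex
print(is_url(tmp / "config.yaml"))             # False, plain local path
print(resource_path(tmp, "config.yaml"))       # /tmp/olmo-util-demo/config.yaml
print(file_size(tmp / "config.yaml"))          # size in bytes via os.stat
# file_size("gs://my-bucket/run1/step1000/rank0.pt") would go through the GCS client instead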
2 changes: 1 addition & 1 deletion scripts/train.py
@@ -191,7 +191,7 @@ def main(cfg: TrainConfig) -> None:

if cfg.load_path is not None:
log.info(f"Loading checkpoint from {cfg.load_path}...")
trainer.restore_checkpoint(Path(cfg.load_path))
trainer.restore_checkpoint(cfg.load_path)
log.info("Checkpoint successfully loaded")

if cfg.force_save_unsharded:
