Skip to content

Commit

Permalink
Guess at the world size
Browse the repository at this point in the history
  • Loading branch information
dirkgr committed Oct 25, 2023
1 parent bb8e2f6 commit b764c37
Showing 1 changed file with 10 additions and 7 deletions.
17 changes: 10 additions & 7 deletions olmo/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -1037,13 +1037,18 @@ def unshard_checkpoint(
device: Optional[torch.device] = None,
) -> Tuple[Dict[str, torch.Tensor], Optional[Dict[str, Any]]]:
device = device or torch.device("cpu")
metadata = self._load_metadata(load_path, local_cache=local_cache)
try:
metadata = self._load_metadata(load_path, local_cache=local_cache)
world_size = metadata.world_size
except FileNotFoundError:
assert isinstance(
load_path, Path
), "Automatically detecting the world size requires the checkpoint to be local."
world_size = sum(1 for _ in (load_path / "train").glob("rank*.pt"))

        # Gather paths to the model state files, potentially downloading them.
log.info("Gathering model state dicts...")
model_state_paths = self._gather_state_dict_paths(
load_path, "model", metadata.world_size, local_cache=local_cache
)
model_state_paths = self._gather_state_dict_paths(load_path, "model", world_size, local_cache=local_cache)

# Load model state dicts one-by-one, materializing and populating the full parameters as we go.
log.info("Materializing full parameters...")
Expand Down Expand Up @@ -1081,9 +1086,7 @@ def unshard_checkpoint(
return full_model_state, None

log.info("Gathering optim state dicts...")
optim_state_paths = self._gather_state_dict_paths(
load_path, "optim", metadata.world_size, local_cache=local_cache
)
optim_state_paths = self._gather_state_dict_paths(load_path, "optim", world_size, local_cache=local_cache)

log.info("Materializing full optim state...")
full_optim_state: Dict[str, Any] = {"state": defaultdict(dict)}
Expand Down

0 comments on commit b764c37

Please sign in to comment.