Commit

add unshard method for OlmoCoreCheckpointer
epwalsh committed Apr 4, 2024
1 parent 4fbabe6 commit 8d035b1
Showing 2 changed files with 26 additions and 0 deletions.
23 changes: 23 additions & 0 deletions olmo/checkpoint.py
@@ -1738,6 +1738,29 @@ def restore_checkpoint(
        barrier()
        return trainer_state

    def unshard_checkpoint(
        self,
        load_path: PathOrStr,
        *,
        local_cache: Optional[PathOrStr] = None,
        load_optimizer_state: bool = True,
        load_trainer_state: bool = True,
        device: Optional[torch.device] = None,
    ) -> Tuple[Dict[str, torch.Tensor], Optional[Dict[str, Any]], Optional[Dict[str, Any]]]:
        from olmo_core.distributed.checkpoint import (  # type: ignore
            unshard_model_state,
            unshard_optim_state,
        )

        model_state = unshard_model_state(load_path, device=device)
        optim_state: Optional[Dict[str, Any]] = None
        train_state: Optional[Dict[str, Any]] = None
        if load_optimizer_state:
            optim_state = cast(Dict[str, Any], unshard_optim_state(load_path, device=device))
        if load_trainer_state:
            train_state = load_state_dict(load_path, "train/rank0.pt", local_cache=local_cache)
        return model_state, optim_state, train_state


def build_sharded_checkpointer(
    cfg: TrainConfig, *, name: Optional[ShardedCheckpointerType] = None
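For reference, here is a minimal sketch of calling the new method directly. The config location, checkpoint directory, and the TrainConfig.load usage below are placeholders and assumptions for illustration, not part of this commit:

import torch

from olmo.checkpoint import OlmoCoreCheckpointer
from olmo.config import TrainConfig

# Placeholder paths; the config location and TrainConfig.load usage are assumptions.
config = TrainConfig.load("checkpoints/step1000/config.yaml")
checkpointer = OlmoCoreCheckpointer(config)

# Unshard a sharded OLMo-core checkpoint onto CPU.
model_state, optim_state, train_state = checkpointer.unshard_checkpoint(
    "checkpoints/step1000",  # sharded checkpoint directory (placeholder)
    load_optimizer_state=True,
    load_trainer_state=True,
    device=torch.device("cpu"),
)

# model_state maps parameter names to full tensors; optim_state and
# train_state are None when the corresponding load_* flag is False.
print(sum(t.numel() for t in model_state.values()), "parameters unsharded")
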
3 changes: 3 additions & 0 deletions scripts/unshard.py
@@ -8,6 +8,7 @@
from olmo.checkpoint import (
    Checkpointer,
    LocalShardedCheckpointer,
    OlmoCoreCheckpointer,
    TorchLegacyShardedCheckpointer,
)
from olmo.config import ShardedCheckpointerType, TrainConfig
@@ -35,6 +36,8 @@ def main(
        checkpointer = TorchLegacyShardedCheckpointer(config)
    elif sharded_checkpoint_type == ShardedCheckpointerType.local:
        checkpointer = LocalShardedCheckpointer(config)
    elif sharded_checkpoint_type == ShardedCheckpointerType.olmo_core:
        checkpointer = OlmoCoreCheckpointer(config)
    else:
        raise NotImplementedError(sharded_checkpoint_type)

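With the new branch in place, an OLMo-core sharded checkpoint can be unsharded through the same flow as the other checkpointer types. A rough sketch of that flow follows, assuming a conventional output layout (model.pt / optim.pt / train.pt) and placeholder paths; neither is dictated by this diff:

from pathlib import Path

import torch

from olmo.checkpoint import OlmoCoreCheckpointer
from olmo.config import ShardedCheckpointerType, TrainConfig

input_dir = Path("checkpoints/step1000")             # placeholder sharded checkpoint
output_dir = Path("checkpoints/step1000-unsharded")  # placeholder destination
output_dir.mkdir(parents=True, exist_ok=True)

config = TrainConfig.load(str(input_dir / "config.yaml"))  # assumed config location
sharded_checkpoint_type = ShardedCheckpointerType.olmo_core

if sharded_checkpoint_type == ShardedCheckpointerType.olmo_core:
    checkpointer = OlmoCoreCheckpointer(config)
else:
    raise NotImplementedError(sharded_checkpoint_type)

model_state, optim_state, train_state = checkpointer.unshard_checkpoint(input_dir)
torch.save(model_state, output_dir / "model.pt")
if optim_state is not None:
    torch.save(optim_state, output_dir / "optim.pt")
if train_state is not None:
    torch.save(train_state, output_dir / "train.pt")
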
