From 8d035b15a4fdab3dfd36e0f11a5921371aae84ec Mon Sep 17 00:00:00 2001
From: epwalsh
Date: Thu, 4 Apr 2024 16:12:59 -0700
Subject: [PATCH] add unshard method for OlmoCoreCheckpointer

---
 olmo/checkpoint.py | 23 +++++++++++++++++++++++
 scripts/unshard.py |  3 +++
 2 files changed, 26 insertions(+)

diff --git a/olmo/checkpoint.py b/olmo/checkpoint.py
index ad7079c13..772549d30 100644
--- a/olmo/checkpoint.py
+++ b/olmo/checkpoint.py
@@ -1738,6 +1738,29 @@ def restore_checkpoint(
         barrier()
         return trainer_state
 
+    def unshard_checkpoint(
+        self,
+        load_path: PathOrStr,
+        *,
+        local_cache: Optional[PathOrStr] = None,
+        load_optimizer_state: bool = True,
+        load_trainer_state: bool = True,
+        device: Optional[torch.device] = None,
+    ) -> Tuple[Dict[str, torch.Tensor], Optional[Dict[str, Any]], Optional[Dict[str, Any]]]:
+        from olmo_core.distributed.checkpoint import (  # type: ignore
+            unshard_model_state,
+            unshard_optim_state,
+        )
+
+        model_state = unshard_model_state(load_path, device=device)
+        optim_state: Optional[Dict[str, Any]] = None
+        train_state: Optional[Dict[str, Any]] = None
+        if load_optimizer_state:
+            optim_state = cast(Dict[str, Any], unshard_optim_state(load_path, device=device))
+        if load_trainer_state:
+            train_state = load_state_dict(load_path, "train/rank0.pt", local_cache=local_cache)
+        return model_state, optim_state, train_state
+
 
 def build_sharded_checkpointer(
     cfg: TrainConfig, *, name: Optional[ShardedCheckpointerType] = None
diff --git a/scripts/unshard.py b/scripts/unshard.py
index 72567b317..e327c85ed 100644
--- a/scripts/unshard.py
+++ b/scripts/unshard.py
@@ -8,6 +8,7 @@
 from olmo.checkpoint import (
     Checkpointer,
     LocalShardedCheckpointer,
+    OlmoCoreCheckpointer,
     TorchLegacyShardedCheckpointer,
 )
 from olmo.config import ShardedCheckpointerType, TrainConfig
@@ -35,6 +36,8 @@ def main(
         checkpointer = TorchLegacyShardedCheckpointer(config)
     elif sharded_checkpoint_type == ShardedCheckpointerType.local:
         checkpointer = LocalShardedCheckpointer(config)
+    elif sharded_checkpoint_type == ShardedCheckpointerType.olmo_core:
+        checkpointer = OlmoCoreCheckpointer(config)
     else:
         raise NotImplementedError(sharded_checkpoint_type)
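
Usage note (not part of the patch): a minimal sketch of how the new method
might be called directly from Python once this lands. The config and
checkpoint paths below are hypothetical placeholders, and loading the run
config via TrainConfig.load is an assumption about the repo's config API.

    # Sketch only: paths are hypothetical; TrainConfig.load(...) is assumed.
    import torch

    from olmo.checkpoint import OlmoCoreCheckpointer
    from olmo.config import TrainConfig

    config = TrainConfig.load("run/config.yaml")  # hypothetical path
    checkpointer = OlmoCoreCheckpointer(config)

    # unshard_checkpoint returns (model_state, optim_state, train_state);
    # optim_state / train_state are None when the corresponding flag is False.
    model_state, optim_state, train_state = checkpointer.unshard_checkpoint(
        "run/step1000",  # hypothetical sharded checkpoint directory
        load_optimizer_state=False,
        load_trainer_state=True,
        device=torch.device("cpu"),
    )
    torch.save(model_state, "model.pt")  # hypothetical output location

This mirrors what scripts/unshard.py does for the other checkpointer types,
with the optimizer state skipped to keep memory usage down.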