diff --git a/olmo/checkpoint.py b/olmo/checkpoint.py
index 1dd355a36..251801100 100644
--- a/olmo/checkpoint.py
+++ b/olmo/checkpoint.py
@@ -1657,6 +1657,65 @@ def _gather_state_dict_paths(
     return [results[rank] for rank in range(world_size)]
 
 
+class OlmoCoreCheckpointer(Checkpointer):
+    def save_checkpoint(
+        self,
+        dir: PathOrStr,
+        fsdp_model: FSDP,
+        optim: Optimizer,
+        trainer_state: Dict[str, Any],
+        *,
+        upload_to: Optional[str] = None,
+    ) -> None:
+        from olmo_core.distributed.checkpoint import (  # type: ignore
+            save_model_and_optim_state,
+        )
+
+        with self._temporary_wd(dir) as checkpoint_dir:
+            log.info("Saving model and optim state...")
+            save_model_and_optim_state(checkpoint_dir, fsdp_model, optim, save_overwrite=self.cfg.save_overwrite)
+            if upload_to is not None and get_fs_local_rank() == 0:
+                for path in Path(checkpoint_dir).glob("**/*"):
+                    if not path.is_file():
+                        continue
+                    upload_target = f"{upload_to.rstrip('/')}/{path.relative_to(checkpoint_dir)}"
+                    log.info(f"Uploading {path} to {upload_target}...")
+                    upload(path, upload_target, save_overwrite=self.cfg.save_overwrite)
+
+            log.info("Saving trainer state...")
+            save_state_dict(
+                checkpoint_dir,
+                f"train/rank{get_global_rank()}.pt",
+                trainer_state,
+                save_overwrite=self.cfg.save_overwrite,
+                upload_to=upload_to,
+            )
+
+            self._save_config(checkpoint_dir, upload_to=upload_to)
+
+    def restore_checkpoint(
+        self,
+        load_path: PathOrStr,
+        fsdp_model: FSDP,
+        optim: Optimizer,
+        *,
+        local_cache: Optional[PathOrStr] = None,
+        load_optimizer_state: bool = True,
+    ) -> Dict[str, Any]:
+        from olmo_core.distributed.checkpoint import (  # type: ignore
+            load_model_and_optim_state,
+        )
+
+        log.info("Loading model and optim state...")
+        load_model_and_optim_state(load_path, fsdp_model, optim if load_optimizer_state else None)
+
+        log.info("Loading trainer state...")
+        trainer_state = load_state_dict(load_path, f"train/rank{get_global_rank()}.pt", local_cache=local_cache)
+
+        barrier()
+        return trainer_state
+
+
 def build_sharded_checkpointer(
     cfg: TrainConfig, *, name: Optional[ShardedCheckpointerType] = None
 ) -> Checkpointer:
@@ -1667,5 +1726,7 @@ def build_sharded_checkpointer(
         return TorchLegacyShardedCheckpointer(cfg)
     elif name == ShardedCheckpointerType.local:
         return LocalShardedCheckpointer(cfg)
+    elif name == ShardedCheckpointerType.olmo_core:
+        return OlmoCoreCheckpointer(cfg)
     else:
         raise NotImplementedError(name)
diff --git a/olmo/config.py b/olmo/config.py
index 6244994a3..042c704ce 100644
--- a/olmo/config.py
+++ b/olmo/config.py
@@ -692,6 +692,7 @@ class ShardedCheckpointerType(StrEnum):
     torch_new = "torch_new"
     torch_legacy = "torch_legacy"
     local = "local"
+    olmo_core = "olmo_core"
 
 
 class ActivationCheckpointingStrategy(StrEnum):
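
For context, a minimal usage sketch (not part of the diff) of how the new `olmo_core` checkpointer type would be obtained through the existing `build_sharded_checkpointer` factory and used to save and restore a checkpoint. Only `build_sharded_checkpointer`, `ShardedCheckpointerType.olmo_core`, `save_checkpoint`, and `restore_checkpoint` come from this change; the `cfg`, `fsdp_model`, `optim`, and `trainer_state` objects, the checkpoint directory, and the example `upload_to` value are assumptions for illustration and would come from the surrounding training script.

    # Hypothetical usage sketch; assumes a training script that already has
    # `cfg` (TrainConfig), `fsdp_model` (FSDP-wrapped model), `optim`, and
    # `trainer_state` in scope.
    from olmo.checkpoint import build_sharded_checkpointer
    from olmo.config import ShardedCheckpointerType

    # Select the new checkpointer explicitly (equivalently, set the sharded
    # checkpointer type to "olmo_core" in the training config).
    checkpointer = build_sharded_checkpointer(cfg, name=ShardedCheckpointerType.olmo_core)

    # Save: writes model/optim state via olmo_core plus per-rank trainer state,
    # optionally mirroring the local files to a remote prefix.
    checkpointer.save_checkpoint(
        "checkpoints/step1000",          # assumed local checkpoint directory
        fsdp_model,
        optim,
        trainer_state,
        upload_to=None,                  # e.g. "s3://bucket/run/step1000" (assumed)
    )

    # Restore: loads model/optim state in place and returns this rank's
    # trainer state dict.
    trainer_state = checkpointer.restore_checkpoint(
        "checkpoints/step1000",
        fsdp_model,
        optim,
        load_optimizer_state=True,
    )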