Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds a new experimental sharded checkpointer from OLMo-core #532

Merged
merged 6 commits into from
Apr 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions olmo/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -1657,6 +1657,65 @@ def _gather_state_dict_paths(
return [results[rank] for rank in range(world_size)]


class OlmoCoreCheckpointer(Checkpointer):
def save_checkpoint(
self,
dir: PathOrStr,
fsdp_model: FSDP,
optim: Optimizer,
trainer_state: Dict[str, Any],
*,
upload_to: Optional[str] = None,
) -> None:
from olmo_core.distributed.checkpoint import ( # type: ignore
save_model_and_optim_state,
)

with self._temporary_wd(dir) as checkpoint_dir:
log.info("Saving model and optim state...")
save_model_and_optim_state(checkpoint_dir, fsdp_model, optim, save_overwrite=self.cfg.save_overwrite)
if upload_to is not None and get_fs_local_rank() == 0:
for path in Path(checkpoint_dir).glob("**/*"):
if not path.is_file():
continue
upload_target = f"{upload_to.rstrip('/')}/{path.relative_to(checkpoint_dir)}"
log.info(f"Uploading {path} to {upload_target}...")
upload(path, upload_target, save_overwrite=self.cfg.save_overwrite)

log.info("Saving trainer state...")
save_state_dict(
checkpoint_dir,
f"train/rank{get_global_rank()}.pt",
trainer_state,
save_overwrite=self.cfg.save_overwrite,
upload_to=upload_to,
)

self._save_config(checkpoint_dir, upload_to=upload_to)

def restore_checkpoint(
self,
load_path: PathOrStr,
fsdp_model: FSDP,
optim: Optimizer,
*,
local_cache: Optional[PathOrStr] = None,
load_optimizer_state: bool = True,
) -> Dict[str, Any]:
from olmo_core.distributed.checkpoint import ( # type: ignore
load_model_and_optim_state,
)

log.info("Loading model and optim state...")
load_model_and_optim_state(load_path, fsdp_model, optim if load_optimizer_state else None)

log.info("Loading trainer state...")
trainer_state = load_state_dict(load_path, f"train/rank{get_global_rank()}.pt", local_cache=local_cache)

barrier()
return trainer_state


def build_sharded_checkpointer(
cfg: TrainConfig, *, name: Optional[ShardedCheckpointerType] = None
) -> Checkpointer:
Expand All @@ -1667,5 +1726,7 @@ def build_sharded_checkpointer(
return TorchLegacyShardedCheckpointer(cfg)
elif name == ShardedCheckpointerType.local:
return LocalShardedCheckpointer(cfg)
elif name == ShardedCheckpointerType.olmo_core:
return OlmoCoreCheckpointer(cfg)
else:
raise NotImplementedError(name)
1 change: 1 addition & 0 deletions olmo/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -692,6 +692,7 @@ class ShardedCheckpointerType(StrEnum):
torch_new = "torch_new"
torch_legacy = "torch_legacy"
local = "local"
olmo_core = "olmo_core"


class ActivationCheckpointingStrategy(StrEnum):
Expand Down
Loading