From b293a09e03773ece3719857b8d06ece1e31abcef Mon Sep 17 00:00:00 2001 From: epwalsh Date: Fri, 4 Aug 2023 14:09:12 -0700 Subject: [PATCH] make loading backwards compat --- olmo/model.py | 15 +++++++++------ olmo/train.py | 4 +++- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/olmo/model.py b/olmo/model.py index 63a247917..1bf9cc488 100644 --- a/olmo/model.py +++ b/olmo/model.py @@ -986,16 +986,19 @@ def from_checkpoint(cls, checkpoint_dir: PathOrStr, device: str = "cpu") -> Olmo def _make_state_dict_compatible(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: # For backwards compatibility prior to fixing https://github.com/allenai/LLM/issues/222 + prefix = "" + if next(iter(state_dict.keys())).startswith((fsdp_prefix := "_fsdp_wrapped_module.")): + prefix = fsdp_prefix if self.config.block_type == BlockType.sequential: for block_idx in range(self.config.n_layers): - norm_w_key = f"transformer.blocks.{block_idx}.norm.weight" - norm_b_key = f"transformer.blocks.{block_idx}.norm.bias" + norm_w_key = f"{prefix}transformer.blocks.{block_idx}.norm.weight" + norm_b_key = f"{prefix}transformer.blocks.{block_idx}.norm.bias" if norm_w_key in state_dict: norm_w = state_dict.pop(norm_w_key) - state_dict[f"transformer.blocks.{block_idx}.attn_norm.weight"] = norm_w - state_dict[f"transformer.blocks.{block_idx}.ff_norm.weight"] = norm_w.clone() + state_dict[f"{prefix}transformer.blocks.{block_idx}.attn_norm.weight"] = norm_w + state_dict[f"{prefix}transformer.blocks.{block_idx}.ff_norm.weight"] = norm_w.clone() if norm_b_key in state_dict: norm_b = state_dict.pop(norm_b_key) - state_dict[f"transformer.blocks.{block_idx}.attn_norm.bias"] = norm_b - state_dict[f"transformer.blocks.{block_idx}.ff_norm.bias"] = norm_b.clone() + state_dict[f"{prefix}transformer.blocks.{block_idx}.attn_norm.bias"] = norm_b + state_dict[f"{prefix}transformer.blocks.{block_idx}.ff_norm.bias"] = norm_b.clone() return state_dict diff --git a/olmo/train.py b/olmo/train.py index da2aa16ff..537be39db 100644 --- a/olmo/train.py +++ b/olmo/train.py @@ -466,7 +466,9 @@ def restore_unsharded_checkpoint(self, load_path: PathOrStr): ): # Load model state. log.info("Loading model state...") - self.fsdp_model.load_state_dict(torch.load(resource_path(load_path, "model.pt"))) + self.fsdp_model.load_state_dict( + self.model._make_state_dict_compatible(torch.load(resource_path(load_path, "model.pt"))) + ) # Load optimizer state. log.info("Loading optimizer state...")