
Commit

make loading backwards compat
epwalsh committed Aug 4, 2023
1 parent 209a268 commit b293a09
Showing 2 changed files with 12 additions and 7 deletions.
15 changes: 9 additions & 6 deletions olmo/model.py
@@ -986,16 +986,19 @@ def from_checkpoint(cls, checkpoint_dir: PathOrStr, device: str = "cpu") -> Olmo
 
     def _make_state_dict_compatible(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
         # For backwards compatibility prior to fixing https://github.com/allenai/LLM/issues/222
+        prefix = ""
+        if next(iter(state_dict.keys())).startswith((fsdp_prefix := "_fsdp_wrapped_module.")):
+            prefix = fsdp_prefix
         if self.config.block_type == BlockType.sequential:
             for block_idx in range(self.config.n_layers):
-                norm_w_key = f"transformer.blocks.{block_idx}.norm.weight"
-                norm_b_key = f"transformer.blocks.{block_idx}.norm.bias"
+                norm_w_key = f"{prefix}transformer.blocks.{block_idx}.norm.weight"
+                norm_b_key = f"{prefix}transformer.blocks.{block_idx}.norm.bias"
                 if norm_w_key in state_dict:
                     norm_w = state_dict.pop(norm_w_key)
-                    state_dict[f"transformer.blocks.{block_idx}.attn_norm.weight"] = norm_w
-                    state_dict[f"transformer.blocks.{block_idx}.ff_norm.weight"] = norm_w.clone()
+                    state_dict[f"{prefix}transformer.blocks.{block_idx}.attn_norm.weight"] = norm_w
+                    state_dict[f"{prefix}transformer.blocks.{block_idx}.ff_norm.weight"] = norm_w.clone()
                 if norm_b_key in state_dict:
                     norm_b = state_dict.pop(norm_b_key)
-                    state_dict[f"transformer.blocks.{block_idx}.attn_norm.bias"] = norm_b
-                    state_dict[f"transformer.blocks.{block_idx}.ff_norm.bias"] = norm_b.clone()
+                    state_dict[f"{prefix}transformer.blocks.{block_idx}.attn_norm.bias"] = norm_b
+                    state_dict[f"{prefix}transformer.blocks.{block_idx}.ff_norm.bias"] = norm_b.clone()
         return state_dict
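
For context, here is a minimal standalone sketch of the key migration this hunk makes prefix-aware (illustrative only; make_compatible and n_layers are placeholder names, not the OLMo API): old sequential-block checkpoints store a single norm per block, newer models expect separate attn_norm and ff_norm weights, and an unsharded FSDP checkpoint may additionally carry a "_fsdp_wrapped_module." prefix on every key.

from typing import Dict

import torch


def make_compatible(state_dict: Dict[str, torch.Tensor], n_layers: int) -> Dict[str, torch.Tensor]:
    # Detect the prefix that FSDP adds to every key of an unsharded checkpoint.
    prefix = ""
    if next(iter(state_dict)).startswith((fsdp_prefix := "_fsdp_wrapped_module.")):
        prefix = fsdp_prefix
    for block_idx in range(n_layers):
        old_key = f"{prefix}transformer.blocks.{block_idx}.norm.weight"
        if old_key in state_dict:
            # The old single layer-norm weight is copied into the two new per-block norms.
            norm_w = state_dict.pop(old_key)
            state_dict[f"{prefix}transformer.blocks.{block_idx}.attn_norm.weight"] = norm_w
            state_dict[f"{prefix}transformer.blocks.{block_idx}.ff_norm.weight"] = norm_w.clone()
    return state_dict


# An FSDP-prefixed key from an old checkpoint is split into the two new norm keys.
old = {"_fsdp_wrapped_module.transformer.blocks.0.norm.weight": torch.ones(8)}
print(sorted(make_compatible(old, n_layers=1)))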
4 changes: 3 additions & 1 deletion olmo/train.py
@@ -466,7 +466,9 @@ def restore_unsharded_checkpoint(self, load_path: PathOrStr):
         ):
             # Load model state.
             log.info("Loading model state...")
-            self.fsdp_model.load_state_dict(torch.load(resource_path(load_path, "model.pt")))
+            self.fsdp_model.load_state_dict(
+                self.model._make_state_dict_compatible(torch.load(resource_path(load_path, "model.pt")))
+            )
 
             # Load optimizer state.
             log.info("Loading optimizer state...")
