Skip to content

Commit 053b938

Browse files
lijia19facebook-github-bot
authored andcommitted
validate checkpoint is consistent with meta_to_tune flag (#2736)
Summary: add a flag that do validation when load_checkpoint passed unexpected meta_to_tune flag. i.e. if this flag is true but checkpoint is not in meta format, or if this flag is flase but checkpoint is in meta format Do validation early to avoid unexpected error later. Differential Revision: D74784778
1 parent baaaf21 commit 053b938

File tree

1 file changed

+16
-0
lines changed

1 file changed

+16
-0
lines changed

torchtune/models/convert_weights.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,22 @@ def get_mapped_key(key: str, mapping_dict: Dict[str, str]) -> str:
7373
return new_key
7474

7575

76+
def is_in_meta_format(state_dict: Dict[str, torch.Tensor]) -> bool:
77+
"""
78+
Check whether the state dict is in Meta's format by checking the presence
79+
of unique keys only available in META format.
80+
"""
81+
unique_meta_keys = {k for k, v in _FROM_META.items() if k != v}
82+
for key in state_dict.keys():
83+
if key not in ["rope.freqs"]: # Skip loading the position embeddings
84+
if any(k.isdigit() for k in key.split(".")):
85+
# Replace layer number with "{}" to create key for lookup
86+
key = re.sub(r"(\.\d+)", ".{}", key)
87+
if key in unique_meta_keys:
88+
return True
89+
return False
90+
91+
7692
def meta_to_tune(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
7793
"""
7894
Convert a state dict from Meta's format to torchtune's format. State dicts

0 commit comments

Comments
 (0)