Commit 49128a6

lijia19 authored and facebook-github-bot committed
validate checkpoint is consistent with meta_to_tune flag
Summary: Add validation for when load_checkpoint is passed an unexpected meta_to_tune flag, i.e. the flag is true but the checkpoint is not in Meta format, or the flag is false but the checkpoint is in Meta format. Validating early avoids unexpected errors later.

Differential Revision: D74784778
1 parent baaaf21 commit 49128a6

File tree

1 file changed: +15 -0 lines changed


torchtune/models/convert_weights.py

Lines changed: 15 additions & 0 deletions
@@ -72,6 +72,21 @@ def get_mapped_key(key: str, mapping_dict: Dict[str, str]) -> str:
 
     return new_key
 
+def is_in_meta_format(state_dict: Dict[str, torch.Tensor]) -> bool:
+    """
+    Check whether the state dict is in Meta's format by checking the presence
+    of unique keys only available in META format.
+    """
+    unique_meta_keys = {k for k, v in _FROM_META.items() if k != v}
+    for key in state_dict.keys():
+        if key not in ["rope.freqs"]:  # Skip loading the position embeddings
+            if any(k.isdigit() for k in key.split(".")):
+                # Replace layer number with "{}" to create key for lookup
+                key = re.sub(r"(\.\d+)", ".{}", key)
+            if key in unique_meta_keys:
+                return True
+    return False
+
 
 def meta_to_tune(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
     """
