diff --git a/torchtune/models/convert_weights.py b/torchtune/models/convert_weights.py index 4c2837ac11..31a93b2b32 100644 --- a/torchtune/models/convert_weights.py +++ b/torchtune/models/convert_weights.py @@ -73,6 +73,22 @@ def get_mapped_key(key: str, mapping_dict: dict[str, str]) -> str: return new_key +def is_in_meta_format(state_dict: dict[str, torch.Tensor]) -> bool: + """ + Check whether the state dict is in Meta's format by checking the presence + of unique keys only available in META format. + """ + unique_meta_keys = {k for k, v in _FROM_META.items() if k != v} + for key in state_dict.keys(): + if key not in ["rope.freqs"]: # Skip loading the position embeddings + if any(k.isdigit() for k in key.split(".")): + # Replace layer number with "{}" to create key for lookup + key = re.sub(r"(\.\d+)", ".{}", key) + if key in unique_meta_keys: + return True + return False + + def meta_to_tune(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: """ Convert a state dict from Meta's format to torchtune's format. State dicts