diff --git a/llama_stack/providers/inline/post_training/torchtune/common/checkpointer.py b/llama_stack/providers/inline/post_training/torchtune/common/checkpointer.py
index 359fc43cab..1c9febc732 100644
--- a/llama_stack/providers/inline/post_training/torchtune/common/checkpointer.py
+++ b/llama_stack/providers/inline/post_training/torchtune/common/checkpointer.py
@@ -4,15 +4,25 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+import json
 import os
 import shutil
 from pathlib import Path
 from typing import Any, Dict, List
 
 import torch
+from safetensors.torch import save_file
 from torchtune import training
 from torchtune.models import convert_weights
-from torchtune.training.checkpointing._utils import ModelType, safe_torch_load
+from torchtune.training.checkpointing._utils import (
+    ADAPTER_CONFIG_FNAME,
+    ADAPTER_MODEL_FNAME,
+    copy_files,
+    ModelType,
+    REPO_ID_FNAME,
+    safe_torch_load,
+    SUFFIXES_TO_NOT_COPY,
+)
 from torchtune.utils._logging import get_logger
 
 logger = get_logger("DEBUG")
@@ -81,83 +91,175 @@ def save_checkpoint(
         state_dict: Dict[str, Any],
         epoch: int,
         adapter_only: bool = False,
+        checkpoint_format: str = "meta",
     ) -> str:
         model_file_path = (
             Path(self._output_dir)
             / f"{self._model_id}-{self._training_algorithm}-{epoch}"
         )
+        if checkpoint_format == "meta":
+            model_file_path.mkdir(parents=True, exist_ok=True)
-        model_file_path.mkdir(parents=True, exist_ok=True)
+            # copy the related files for inference
+            source_path = Path.joinpath(self._checkpoint_dir, "params.json")
+            if source_path.exists():
+                shutil.copy(
+                    source_path,
+                    Path.joinpath(model_file_path, "params.json"),
+                )
+            source_path = Path.joinpath(self._checkpoint_dir, "tokenizer.model")
+            if source_path.exists():
+                shutil.copy(
+                    source_path,
+                    Path.joinpath(model_file_path, "tokenizer.model"),
+                )
+            source_path = Path.joinpath(self._checkpoint_dir, "orig_params.json")
+            if source_path.exists():
+                shutil.copy(
+                    source_path,
+                    Path.joinpath(model_file_path, "orig_params.json"),
+                )
-        # copy the related files for inference
-        source_path = Path.joinpath(self._checkpoint_dir, "params.json")
-        if source_path.exists():
-            shutil.copy(
-                source_path,
-                Path.joinpath(model_file_path, "params.json"),
-            )
-        source_path = Path.joinpath(self._checkpoint_dir, "tokenizer.model")
-        if source_path.exists():
-            shutil.copy(
-                source_path,
-                Path.joinpath(model_file_path, "tokenizer.model"),
-            )
-        source_path = Path.joinpath(self._checkpoint_dir, "orig_params.json")
-        if source_path.exists():
-            shutil.copy(
-                source_path,
-                Path.joinpath(model_file_path, "orig_params.json"),
+            if not adapter_only:
+                model_state_dict = state_dict[training.MODEL_KEY]
+                if self._model_type == ModelType.LLAMA3_VISION:
+                    from torchtune.models.llama3_2_vision._convert_weights import (
+                        llama3_vision_tune_to_meta,
+                    )
+
+                    state_dict[training.MODEL_KEY] = llama3_vision_tune_to_meta(
+                        model_state_dict
+                    )
+                else:
+                    # llama3_2 has tied weights, so we need to add the output.weight key
+                    if (
+                        self._model_type == ModelType.LLAMA3_2
+                        and "output.weight" not in model_state_dict
+                    ):
+                        model_state_dict["output.weight"] = model_state_dict[
+                            "tok_embeddings.weight"
+                        ]
+
+                    state_dict[training.MODEL_KEY] = convert_weights.tune_to_meta(
+                        model_state_dict
+                    )
+
+                model_file_name = Path.joinpath(model_file_path, "consolidated.00.pth")
+
+                torch.save(state_dict[training.MODEL_KEY], model_file_name)
+                logger.info(
+                    "Model checkpoint of size "
+                    f"{os.path.getsize(model_file_name) / 1000**3:.2f} GB "
+                    f"saved to {model_file_name}"
+                )
+
+            if training.ADAPTER_KEY in state_dict:
+                adapter_file_path = model_file_path / "adapter"
+                adapter_file_path.mkdir(parents=True, exist_ok=True)
+                adapter_file_name = Path.joinpath(adapter_file_path, "adapter.pth")
+                torch.save(state_dict[training.ADAPTER_KEY], adapter_file_name)
+                logger.info(
+                    "Adapter checkpoint of size "
+                    f"{os.path.getsize(adapter_file_name) / 1000**3:.2f} GB "
+                    f"saved to {adapter_file_name}"
+                )
+
+            elif adapter_only:
+                raise ValueError(
+                    "Adapter checkpoint not found in state_dict. Please ensure that the state_dict contains adapter weights."
+                )
+        elif checkpoint_format == "hf":
+            # Note: for saving Hugging Face format checkpoints, we only support saving adapter weights for now
+
+            # the config.json file contains model params needed for state dict conversion
+            config = json.loads(
+                Path.joinpath(self._checkpoint_dir.parent, "config.json").read_text()
             )
-        if not adapter_only:
-            model_state_dict = state_dict[training.MODEL_KEY]
-            if self._model_type == ModelType.LLAMA3_VISION:
-                from torchtune.models.llama3_2_vision._convert_weights import (
-                    llama3_vision_tune_to_meta,
+            # repo_id is necessary when saving an adapter config, so it's compatible with HF.
+            # This json file is produced and saved in the download step.
+            # contents are {"repo_id": "some_model/some_model_version"}
+            repo_id_path = Path.joinpath(
+                self._checkpoint_dir.parent, REPO_ID_FNAME
+            ).with_suffix(".json")
+            self.repo_id = None
+            if repo_id_path.exists():
+                with open(repo_id_path, "r") as json_file:
+                    data = json.load(json_file)
+                    self.repo_id = data.get("repo_id")
+
+            if training.ADAPTER_KEY in state_dict:
+                # TODO: saving it "as is" is a requirement because, if we only save with
+                # convert_weights.tune_to_peft_adapter_weights, we do NOT have a fn
+                # convert_weights.peft_to_tune. The .pt format is not needed, but
+                # it is an easy way to distinguish the adapters. Ideally we should save only one.
+                output_path = Path.joinpath(
+                    model_file_path, ADAPTER_MODEL_FNAME
+                ).with_suffix(".pt")
+                output_path.parent.mkdir(parents=True, exist_ok=True)
+                torch.save(state_dict[training.ADAPTER_KEY], output_path)
+                logger.info(
+                    "Adapter checkpoint of size "
+                    f"{os.path.getsize(output_path) / 1024**3:.2f} GiB "
+                    f"saved to {output_path}"
                 )
-                state_dict[training.MODEL_KEY] = llama3_vision_tune_to_meta(
-                    model_state_dict
+                state_dict[training.ADAPTER_KEY] = (
+                    convert_weights.tune_to_peft_adapter_weights(
+                        state_dict[training.ADAPTER_KEY],
+                        num_heads=config["num_attention_heads"],
+                        num_kv_heads=config["num_key_value_heads"],
+                        dim=config["hidden_size"],
+                        head_dim=config.get("head_dim", None),
+                    )
+                )
+                output_path = Path.joinpath(
+                    model_file_path, "adapter", ADAPTER_MODEL_FNAME
+                )
+                output_path.parent.mkdir(parents=True, exist_ok=True)
+                output_path = output_path.with_suffix(".safetensors")
+                save_file(
+                    state_dict[training.ADAPTER_KEY],
+                    output_path,
+                    metadata={"format": "pt"},
+                )
+                logger.info(
+                    "Adapter checkpoint of size "
+                    f"{os.path.getsize(output_path) / 1024**3:.2f} GiB "
+                    f"saved to {output_path}"
                 )
             else:
-                # llama3_2 has tied weights, so we need to add the output.weight key
-                if (
-                    self._model_type == ModelType.LLAMA3_2
-                    and "output.weight" not in model_state_dict
-                ):
-                    model_state_dict["output.weight"] = model_state_dict[
-                        "tok_embeddings.weight"
-                    ]
-
-                state_dict[training.MODEL_KEY] = convert_weights.tune_to_meta(
-                    model_state_dict
+                raise ValueError(
+                    "Adapter checkpoint not found in state_dict. Please ensure that the state_dict contains adapter weights."
                 )
-            model_file_name = Path.joinpath(model_file_path, "consolidated.00.pth")
-
-            torch.save(state_dict[training.MODEL_KEY], model_file_name)
-            logger.info(
-                "Model checkpoint of size "
-                f"{os.path.getsize(model_file_name) / 1000**3:.2f} GB "
-                f"saved to {model_file_name}"
-            )
+            if training.ADAPTER_CONFIG in state_dict:
+                state_dict[training.ADAPTER_CONFIG] = (
+                    convert_weights.tune_to_peft_adapter_config(
+                        adapter_config=state_dict[training.ADAPTER_CONFIG],
+                        base_model_name_or_path=self.repo_id,
+                    )
+                )
-        if training.ADAPTER_KEY in state_dict:
-            adapter_file_path = model_file_path / "adapter"
-            adapter_file_path.mkdir(parents=True, exist_ok=True)
-            adapter_file_name = Path.joinpath(adapter_file_path, "adapter.pth")
-            torch.save(state_dict[training.ADAPTER_KEY], adapter_file_name)
-            logger.info(
-                "Adapter checkpoint of size "
-                f"{os.path.getsize(adapter_file_name) / 1000**3:.2f} GB "
-                f"saved to {adapter_file_name}"
-            )
+                output_path = Path.joinpath(
+                    model_file_path, "adapter", ADAPTER_CONFIG_FNAME
+                ).with_suffix(".json")
+                with open(output_path, "w") as f:
+                    json.dump(state_dict[training.ADAPTER_CONFIG], f)
+                logger.info(
+                    "Adapter config of size "
+                    f"{os.path.getsize(output_path) / 1024**3:.2f} GiB "
+                    f"saved to {output_path}"
                 )
-        elif adapter_only:
-            raise ValueError(
-                "Adapter checkpoint not found in state_dict. Please ensure that the state_dict contains adapter weights."
+            # Save all files in ckpt_dir, except model weights and mapping, to model_file_path,
+            # so it's easy to run inference with the model using this epoch's checkpoint
+            copy_files(
+                self._checkpoint_dir.parent,
+                model_file_path,
+                ignore_suffixes=SUFFIXES_TO_NOT_COPY,
             )
-
-        print("model_file_path", str(model_file_path))
+        else:
+            raise ValueError(f"Unsupported checkpoint format: {checkpoint_format}")
 
         return str(model_file_path)
diff --git a/llama_stack/providers/inline/post_training/torchtune/config.py b/llama_stack/providers/inline/post_training/torchtune/config.py
index 3ffa55c707..34a48589d4 100644
--- a/llama_stack/providers/inline/post_training/torchtune/config.py
+++ b/llama_stack/providers/inline/post_training/torchtune/config.py
@@ -11,3 +11,4 @@ class TorchtunePostTrainingConfig(BaseModel):
     torch_seed: Optional[int] = None
+    checkpoint_format: Optional[str] = "meta"
diff --git a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
index 80e206ebbb..ac61bc6ccc 100644
--- a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
+++ b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
@@ -129,6 +129,7 @@ def model_checkpoint_dir(model) -> str:
 
         self.checkpoint_dir = model_checkpoint_dir(model)
         self._output_dir = str(DEFAULT_CHECKPOINT_DIR)
+        self._checkpoint_format = config.checkpoint_format
 
         self.seed = training.set_seed(seed=config.torch_seed)
         self.epochs_run = 0
@@ -444,6 +445,7 @@ async def save_checkpoint(self, epoch: int) -> str:
         return self._checkpointer.save_checkpoint(
             ckpt_dict,
             epoch=epoch,
+            checkpoint_format=self._checkpoint_format,
         )
 
     async def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
@@ -488,7 +490,7 @@ async def train(self) -> Tuple[Dict[str, Any], List[Checkpoint]]:
         # Update the sampler to ensure data is correctly shuffled across epochs
         # in case shuffle is True
         metric_logger = DiskLogger(
-            log_dir=self._output_dir + f"/{self.model_id}-sft-{curr_epoch}"
+            log_dir=self._output_dir + f"/{self.model_id}-sft-{curr_epoch}/log"
         )
         self._training_sampler.set_epoch(curr_epoch)
         loss_to_log = 0.0
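
Usage note: checkpoint_format defaults to "meta", so existing flows keep the consolidated-.pth layout; setting it to "hf" switches save_checkpoint to the adapter-only PEFT layout (adapter/adapter_model.safetensors and adapter/adapter_config.json, plus a torchtune-native adapter_model.pt kept for round-tripping). A minimal sketch of opting in, assuming the config has no required fields beyond those shown in the config.py hunk above:

    from llama_stack.providers.inline.post_training.torchtune.config import (
        TorchtunePostTrainingConfig,
    )

    # "meta" (default): consolidated.00.pth plus params.json / tokenizer.model.
    # "hf": adapter-only checkpoint in Hugging Face PEFT layout.
    config = TorchtunePostTrainingConfig(checkpoint_format="hf")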