Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[post training] support save hf safetensor format checkpoint #845

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,25 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import json
import os
import shutil
from pathlib import Path
from typing import Any, Dict, List

import torch
from safetensors.torch import save_file
from torchtune import training
from torchtune.models import convert_weights
from torchtune.training.checkpointing._utils import ModelType, safe_torch_load
from torchtune.training.checkpointing._utils import (
ADAPTER_CONFIG_FNAME,
ADAPTER_MODEL_FNAME,
copy_files,
ModelType,
REPO_ID_FNAME,
safe_torch_load,
SUFFIXES_TO_NOT_COPY,
)
from torchtune.utils._logging import get_logger

logger = get_logger("DEBUG")
Expand Down Expand Up @@ -81,83 +91,175 @@ def save_checkpoint(
state_dict: Dict[str, Any],
epoch: int,
adapter_only: bool = False,
checkpoint_format: str = "meta",
) -> str:
model_file_path = (
Path(self._output_dir)
/ f"{self._model_id}-{self._training_algorithm}-{epoch}"
)
if checkpoint_format == "meta":
model_file_path.mkdir(parents=True, exist_ok=True)

model_file_path.mkdir(parents=True, exist_ok=True)
# copy the related files for inference
source_path = Path.joinpath(self._checkpoint_dir, "params.json")
if source_path.exists():
shutil.copy(
source_path,
Path.joinpath(model_file_path, "params.json"),
)
source_path = Path.joinpath(self._checkpoint_dir, "tokenizer.model")
if source_path.exists():
shutil.copy(
source_path,
Path.joinpath(model_file_path, "tokenizer.model"),
)
source_path = Path.joinpath(self._checkpoint_dir, "orig_params.json")
if source_path.exists():
shutil.copy(
source_path,
Path.joinpath(model_file_path, "orig_params.json"),
)

# copy the related files for inference
source_path = Path.joinpath(self._checkpoint_dir, "params.json")
if source_path.exists():
shutil.copy(
source_path,
Path.joinpath(model_file_path, "params.json"),
)
source_path = Path.joinpath(self._checkpoint_dir, "tokenizer.model")
if source_path.exists():
shutil.copy(
source_path,
Path.joinpath(model_file_path, "tokenizer.model"),
)
source_path = Path.joinpath(self._checkpoint_dir, "orig_params.json")
if source_path.exists():
shutil.copy(
source_path,
Path.joinpath(model_file_path, "orig_params.json"),
if not adapter_only:
model_state_dict = state_dict[training.MODEL_KEY]
if self._model_type == ModelType.LLAMA3_VISION:
from torchtune.models.llama3_2_vision._convert_weights import (
llama3_vision_tune_to_meta,
)

state_dict[training.MODEL_KEY] = llama3_vision_tune_to_meta(
model_state_dict
)
else:
# llama3_2 has tied weights, so we need to add the output.weight key
if (
self._model_type == ModelType.LLAMA3_2
and "output.weight" not in model_state_dict
):
model_state_dict["output.weight"] = model_state_dict[
"tok_embeddings.weight"
]

state_dict[training.MODEL_KEY] = convert_weights.tune_to_meta(
model_state_dict
)

model_file_name = Path.joinpath(model_file_path, "consolidated.00.pth")

torch.save(state_dict[training.MODEL_KEY], model_file_name)
logger.info(
"Model checkpoint of size "
f"{os.path.getsize(model_file_name) / 1000**3:.2f} GB "
f"saved to {model_file_name}"
)

if training.ADAPTER_KEY in state_dict:
adapter_file_path = model_file_path / "adapter"
adapter_file_path.mkdir(parents=True, exist_ok=True)
adapter_file_name = Path.joinpath(adapter_file_path, "adapter.pth")
torch.save(state_dict[training.ADAPTER_KEY], adapter_file_name)
logger.info(
"Adapter checkpoint of size "
f"{os.path.getsize(adapter_file_name) / 1000**3:.2f} GB "
f"saved to {adapter_file_name}"
)

elif adapter_only:
raise ValueError(
"Adapter checkpoint not found in state_dict. Please ensure that the state_dict contains adapter weights."
)
elif checkpoint_format == "hf":
# Note: for saving hugging face format checkpoints, we only suppport saving adapter weights now

# the config.json file contains model params needed for state dict conversion
config = json.loads(
Path.joinpath(self._checkpoint_dir.parent, "config.json").read_text()
)

if not adapter_only:
model_state_dict = state_dict[training.MODEL_KEY]
if self._model_type == ModelType.LLAMA3_VISION:
from torchtune.models.llama3_2_vision._convert_weights import (
llama3_vision_tune_to_meta,
# repo_id is necessary for when saving an adapter config, so its compatible with HF.
# This json file is produced and saved in the download step.
# contents are {"repo_id": "some_model/some_model_version"}
repo_id_path = Path.joinpath(
self._checkpoint_dir.parent, REPO_ID_FNAME
).with_suffix(".json")
self.repo_id = None
if repo_id_path.exists():
with open(repo_id_path, "r") as json_file:
data = json.load(json_file)
self.repo_id = data.get("repo_id")

if training.ADAPTER_KEY in state_dict:
# TODO: saving it "as is" is a requirement because, if we only save with
# convert_weights.tune_to_peft_adapter_weights, we do NOT have a fn
# convert_weights.peft_to_tune. The .pt format is not needed, but
# it is an easy way to distinguish the adapters. Ideally we should save only one.
output_path = Path.joinpath(
model_file_path, ADAPTER_MODEL_FNAME
).with_suffix(".pt")
output_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(state_dict[training.ADAPTER_KEY], output_path)
logger.info(
"Adapter checkpoint of size "
f"{os.path.getsize(output_path) / 1024**3:.2f} GiB "
f"saved to {output_path}"
)

state_dict[training.MODEL_KEY] = llama3_vision_tune_to_meta(
model_state_dict
state_dict[training.ADAPTER_KEY] = (
convert_weights.tune_to_peft_adapter_weights(
state_dict[training.ADAPTER_KEY],
num_heads=config["num_attention_heads"],
num_kv_heads=config["num_key_value_heads"],
dim=config["hidden_size"],
head_dim=config.get("head_dim", None),
)
)
output_path = Path.joinpath(
model_file_path, "adapter", ADAPTER_MODEL_FNAME
)
output_path.parent.mkdir(parents=True, exist_ok=True)
output_path = output_path.with_suffix(".safetensors")
save_file(
state_dict[training.ADAPTER_KEY],
output_path,
metadata={"format": "pt"},
)
logger.info(
"Adapter checkpoint of size "
f"{os.path.getsize(output_path) / 1024**3:.2f} GiB "
f"saved to {output_path}"
)
else:
# llama3_2 has tied weights, so we need to add the output.weight key
if (
self._model_type == ModelType.LLAMA3_2
and "output.weight" not in model_state_dict
):
model_state_dict["output.weight"] = model_state_dict[
"tok_embeddings.weight"
]

state_dict[training.MODEL_KEY] = convert_weights.tune_to_meta(
model_state_dict
raise ValueError(
"Adapter checkpoint not found in state_dict. Please ensure that the state_dict contains adapter weights."
)

model_file_name = Path.joinpath(model_file_path, "consolidated.00.pth")

torch.save(state_dict[training.MODEL_KEY], model_file_name)
logger.info(
"Model checkpoint of size "
f"{os.path.getsize(model_file_name) / 1000**3:.2f} GB "
f"saved to {model_file_name}"
)
if training.ADAPTER_CONFIG in state_dict:
state_dict[training.ADAPTER_CONFIG] = (
convert_weights.tune_to_peft_adapter_config(
adapter_config=state_dict[training.ADAPTER_CONFIG],
base_model_name_or_path=self.repo_id,
)
)

if training.ADAPTER_KEY in state_dict:
adapter_file_path = model_file_path / "adapter"
adapter_file_path.mkdir(parents=True, exist_ok=True)
adapter_file_name = Path.joinpath(adapter_file_path, "adapter.pth")
torch.save(state_dict[training.ADAPTER_KEY], adapter_file_name)
logger.info(
"Adapter checkpoint of size "
f"{os.path.getsize(adapter_file_name) / 1000**3:.2f} GB "
f"saved to {adapter_file_name}"
)
output_path = Path.joinpath(
model_file_path, "adapter", ADAPTER_CONFIG_FNAME
).with_suffix(".json")
with open(output_path, "w") as f:
json.dump(state_dict[training.ADAPTER_CONFIG], f)
logger.info(
"Adapter checkpoint of size "
f"{os.path.getsize(output_path) / 1024**3:.2f} GiB "
f"saved to {output_path}"
)

elif adapter_only:
raise ValueError(
"Adapter checkpoint not found in state_dict. Please ensure that the state_dict contains adapter weights."
# Save all files in ckpt_dir, except model weights and mapping, to output_dir/epoch_{epoch}
# So its easy to run inference with the model using this epoch's checkpoint
copy_files(
self._checkpoint_dir.parent,
model_file_path,
ignore_suffixes=SUFFIXES_TO_NOT_COPY,
)

print("model_file_path", str(model_file_path))
else:
raise ValueError(f"Unsupported checkpoint format: {format}")

return str(model_file_path)
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@

class TorchtunePostTrainingConfig(BaseModel):
torch_seed: Optional[int] = None
checkpoint_format: Optional[str] = "meta"
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ def model_checkpoint_dir(model) -> str:
self.checkpoint_dir = model_checkpoint_dir(model)

self._output_dir = str(DEFAULT_CHECKPOINT_DIR)
self._checkpoint_format = config.checkpoint_format

self.seed = training.set_seed(seed=config.torch_seed)
self.epochs_run = 0
Expand Down Expand Up @@ -444,6 +445,7 @@ async def save_checkpoint(self, epoch: int) -> str:
return self._checkpointer.save_checkpoint(
ckpt_dict,
epoch=epoch,
checkpoint_format=self._checkpoint_format,
)

async def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
Expand Down Expand Up @@ -488,7 +490,7 @@ async def train(self) -> Tuple[Dict[str, Any], List[Checkpoint]]:
# Update the sampler to ensure data is correctly shuffled across epochs
# in case shuffle is True
metric_logger = DiskLogger(
log_dir=self._output_dir + f"/{self.model_id}-sft-{curr_epoch}"
log_dir=self._output_dir + f"/{self.model_id}-sft-{curr_epoch}/log"
)
self._training_sampler.set_epoch(curr_epoch)
loss_to_log = 0.0
Expand Down
Loading