From 32cb0fa0c5939c3a61234f896f1fb4375c2e02c3 Mon Sep 17 00:00:00 2001
From: Luca Soldaini
Date: Wed, 19 Feb 2025 12:11:26 -0800
Subject: [PATCH] Conversion script for OLMo 2 models trained with OLMo core to HuggingFace (#158)

Verified it works by converting a public checkpoint to olmo-core and back; the GSM8K score is [the](https://beaker.allen.ai/orgs/ai2/workspaces/oe-data/work/01JKPJB2KZSP849R0C7H50KMAW?taskId=01JKPJB2MGE2SV5ZZJJ431RV2Y&jobId=01JKPJB2RMAGHVVSW1DSH4JQZE) [same](https://beaker.allen.ai/orgs/ai2/workspaces/oe-data/work/01JKPJAG7MEJ48WZQD9DT9XT7Z?taskId=01JKPJAG8515DZ02AM27SR8XKN&jobId=01JKPJAGCCEDH2X9RNNWZ7FFAV).
---
 CHANGELOG.md                                  |   4 +
 docs/source/examples/huggingface.rst          |   4 +-
 ...point.py => convert_checkpoint_from_hf.py} |   9 +-
 .../huggingface/convert_checkpoint_to_hf.py   | 297 ++++++++++++++++++
 4 files changed, 311 insertions(+), 3 deletions(-)
 rename src/examples/huggingface/{convert_checkpoint.py => convert_checkpoint_from_hf.py} (95%)
 create mode 100644 src/examples/huggingface/convert_checkpoint_to_hf.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 22bb8fa9a..9ef732ef2 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ## Unreleased
 
+### Added
+
+- Added a conversion script for OLMo 2 checkpoints to HuggingFace format.
+
 ### Fixed
 
 - Ensure certain optimizer param group fields are not overridden by the values in a checkpoint.
diff --git a/docs/source/examples/huggingface.rst b/docs/source/examples/huggingface.rst
index 01baa543f..5a5510f56 100644
--- a/docs/source/examples/huggingface.rst
+++ b/docs/source/examples/huggingface.rst
@@ -13,7 +13,7 @@ It would be straight forward to adapt this script to convert in the other direct
 .. seealso::
     See the `train a Llama model `_ example to learn how to use OLMo-core's training API to pretrain or fine-tune any Llama-like language model.
 
-.. tab:: ``src/examples/huggingface/convert_checkpoint.py``
+.. tab:: ``src/examples/huggingface/convert_checkpoint_from_hf.py``
 
-    .. literalinclude:: ../../../src/examples/huggingface/convert_checkpoint.py
+    .. literalinclude:: ../../../src/examples/huggingface/convert_checkpoint_from_hf.py
         :language: py
diff --git a/src/examples/huggingface/convert_checkpoint.py b/src/examples/huggingface/convert_checkpoint_from_hf.py
similarity index 95%
rename from src/examples/huggingface/convert_checkpoint.py
rename to src/examples/huggingface/convert_checkpoint_from_hf.py
index 6559ee43c..563cec7de 100644
--- a/src/examples/huggingface/convert_checkpoint.py
+++ b/src/examples/huggingface/convert_checkpoint_from_hf.py
@@ -6,7 +6,9 @@ HuggingFace.
""" +import json import logging +import os import torch from transformers import AutoModelForCausalLM @@ -127,7 +129,10 @@ def convert_checkpoint() -> AutoModelForCausalLM: assert len(state_dict) == 0 log.info(f"Saving converted model checkpoint '{SAVE_PATH}'...") - save_state_dict(SAVE_PATH, {"model": new_state_dict}) + save_state_dict(os.path.join(SAVE_PATH, "model_and_optim"), {"model": new_state_dict}) + + with open(os.path.join(SAVE_PATH, "config.json"), "w") as f: + json.dump({"model": MODEL_CONFIG.as_dict()}, f) return hf_model @@ -159,5 +164,7 @@ def validate_conversion(hf_model): if __name__ == "__main__": prepare_cli_environment() + + config = MODEL_CONFIG.as_dict() hf_model = convert_checkpoint() validate_conversion(hf_model) diff --git a/src/examples/huggingface/convert_checkpoint_to_hf.py b/src/examples/huggingface/convert_checkpoint_to_hf.py new file mode 100644 index 000000000..e0bf5dc4c --- /dev/null +++ b/src/examples/huggingface/convert_checkpoint_to_hf.py @@ -0,0 +1,297 @@ +""" +Example script to convert a OLMo Core model checkpoint to a HuggingFace model checkpoint. +""" + +import json +import logging +import re +import shutil +from argparse import ArgumentParser +from contextlib import contextmanager +from pathlib import Path +from tempfile import TemporaryDirectory +from typing import Generator + +import torch +from safetensors.torch import load_file +from transformers import AutoModelForCausalLM, AutoTokenizer, GPT2Tokenizer, Olmo2Config + +from olmo_core.distributed.checkpoint import unshard_checkpoint +from olmo_core.utils import prepare_cli_environment + +log = logging.getLogger(__name__) + + +try: + from accelerate import init_empty_weights +except ImportError: + pass + + @contextmanager + def init_empty_weights(include_buffers: bool = False) -> Generator[None, None, None]: + log.warning("accelerate not installed, will initialize weights.") + yield None + + +def load_state_dict(checkpoint_path: str | Path) -> dict[str, torch.Tensor]: + """ + Load a state dict from either a PyTorch checkpoint or safetensors file. + + Args: + checkpoint_path: Path to checkpoint file (.pt, .pth, .bin, or .safetensors) + + Returns: + Dictionary containing model state dict + """ + checkpoint_path = Path(checkpoint_path) + + if checkpoint_path.suffix == ".safetensors": + # Load safetensors format + state_dict = load_file(checkpoint_path) + else: + # Load PyTorch format (.pt, .pth, .bin) + state_dict = torch.load(checkpoint_path, map_location="cpu") + + # Handle both cases: + # 1. Direct state dict + # 2. Nested state dict under 'model' or 'state_dict' key + if "model" in state_dict: + state_dict = state_dict["model"] + elif "state_dict" in state_dict: + state_dict = state_dict["state_dict"] + + return state_dict + + +def convert_to_hf_checkpoint( + olmo_checkpoint_path: str | Path, + output_path: str | Path, + olmo_core_config: dict, + tokenizer: GPT2Tokenizer, + max_sequence_length: int = -1, +) -> None: + """ + Convert an OLMo core checkpoint to Hugging Face format. 
+
+    Args:
+        olmo_checkpoint_path: Path to the OLMo core checkpoint
+        output_path: Where to save the converted HF model
+        olmo_core_config: OLMo core configuration dictionary
+        tokenizer: HuggingFace tokenizer instance
+    """
+    log.info(f"Loading OLMo core checkpoint from '{olmo_checkpoint_path}'")
+    olmo_state_dict = load_state_dict(olmo_checkpoint_path)
+
+    # Initialize new state dict for HF format
+    hf_state_dict = {}
+
+    # Map OLMo-core keys to HF keys
+    hf_state_dict["model.embed_tokens.weight"] = olmo_state_dict.pop("embeddings.weight")
+
+    if "norm.weight" in olmo_state_dict:
+        log.info("Using norm.weight as model.norm.weight")
+        hf_state_dict["model.norm.weight"] = olmo_state_dict.pop("norm.weight")
+    elif "lm_head.norm.weight" in olmo_state_dict:
+        log.info("Using lm_head.norm.weight as model.norm.weight")
+        hf_state_dict["model.norm.weight"] = olmo_state_dict.pop("lm_head.norm.weight")
+    else:
+        raise ValueError("No norm.weight or lm_head.norm.weight found in the state dict")
+
+    if "w_out.weight" in olmo_state_dict:
+        log.info("Using w_out.weight as lm_head.weight")
+        hf_state_dict["lm_head.weight"] = olmo_state_dict.pop("w_out.weight")
+    elif "lm_head.w_out.weight" in olmo_state_dict:
+        log.info("Using lm_head.w_out.weight as lm_head.weight")
+        hf_state_dict["lm_head.weight"] = olmo_state_dict.pop("lm_head.w_out.weight")
+    else:
+        raise ValueError("No w_out.weight or lm_head.w_out.weight found in the state dict")
+
+    # Count number of layers from the state dict keys
+    layer_ids = [
+        match.group(1)
+        for key in olmo_state_dict.keys()
+        if (match := re.match(r"blocks\.(\d+)\.", key))
+    ]
+    assert layer_ids, "No layer IDs found in the state dict keys"
+    n_layers = max(map(int, layer_ids)) + 1
+
+    for block in range(n_layers):
+        # Attention
+        hf_state_dict[f"model.layers.{block}.self_attn.q_proj.weight"] = olmo_state_dict.pop(
+            f"blocks.{block}.attention.w_q.weight"
+        )
+        hf_state_dict[f"model.layers.{block}.self_attn.k_proj.weight"] = olmo_state_dict.pop(
+            f"blocks.{block}.attention.w_k.weight"
+        )
+        hf_state_dict[f"model.layers.{block}.self_attn.v_proj.weight"] = olmo_state_dict.pop(
+            f"blocks.{block}.attention.w_v.weight"
+        )
+        hf_state_dict[f"model.layers.{block}.self_attn.o_proj.weight"] = olmo_state_dict.pop(
+            f"blocks.{block}.attention.w_out.weight"
+        )
+
+        # MLP
+        hf_state_dict[f"model.layers.{block}.mlp.gate_proj.weight"] = olmo_state_dict.pop(
+            f"blocks.{block}.feed_forward.w1.weight"
+        )
+        hf_state_dict[f"model.layers.{block}.mlp.down_proj.weight"] = olmo_state_dict.pop(
+            f"blocks.{block}.feed_forward.w2.weight"
+        )
+        hf_state_dict[f"model.layers.{block}.mlp.up_proj.weight"] = olmo_state_dict.pop(
+            f"blocks.{block}.feed_forward.w3.weight"
+        )
+
+        # Check if we have q_norm and k_norm (non-Llama model)
+        has_qk_norm = f"blocks.{block}.attention.q_norm.weight" in olmo_state_dict
+
+        if has_qk_norm:
+            # Non-Llama model
+            hf_state_dict[
+                f"model.layers.{block}.post_attention_layernorm.weight"
+            ] = olmo_state_dict.pop(f"blocks.{block}.attention_norm.weight")
+            hf_state_dict[
+                f"model.layers.{block}.post_feedforward_layernorm.weight"
+            ] = olmo_state_dict.pop(f"blocks.{block}.feed_forward_norm.weight")
+            hf_state_dict[f"model.layers.{block}.self_attn.q_norm.weight"] = olmo_state_dict.pop(
+                f"blocks.{block}.attention.q_norm.weight"
+            )
+            hf_state_dict[f"model.layers.{block}.self_attn.k_norm.weight"] = olmo_state_dict.pop(
+                f"blocks.{block}.attention.k_norm.weight"
+            )
+        else:
+            # Llama model
+            hf_state_dict[
f"model.layers.{block}.post_attention_layernorm.weight" + ] = olmo_state_dict.pop(f"blocks.{block}.feed_forward_norm.weight") + hf_state_dict[f"model.layers.{block}.input_layernorm.weight"] = olmo_state_dict.pop( + f"blocks.{block}.attention_norm.weight" + ) + + # Verify we used all keys + if len(olmo_state_dict) > 0: + _unused_keys = "\n".join(f"\t-{k}" for k in olmo_state_dict.keys()) + raise ValueError(f"Unused keys in the state dict:\n{_unused_keys}") + + if not hasattr(tokenizer, "vocab") or not isinstance(tokenizer.vocab, dict): + raise ValueError("Tokenizer must have a vocab dictionary") + + if ( + k := olmo_core_config.get("model", {}) + .get("block", {}) + .get("layer_norm", {}) + .get("name", None) + ) != "rms": + raise ValueError(f"Only RMSNorm is supported, found {k}") + + max_sequence_length = min( + int(olmo_core_config.get("dataset", {}).get("sequence_length", max_sequence_length)), + tokenizer.model_max_length, + ) + + if max_sequence_length <= 0: + raise ValueError(f"Missing or invalid sequence length: {max_sequence_length}") + + # Create HF model instance and load state dict + huggingface_config = Olmo2Config( + vocab_size=olmo_core_config["model"]["vocab_size"], + hidden_size=olmo_core_config["model"]["d_model"], + intermediate_size=olmo_core_config["model"]["block"]["feed_forward"]["hidden_size"], + num_hidden_layers=olmo_core_config["model"]["n_layers"], + num_attention_heads=(n_heads := olmo_core_config["model"]["block"]["attention"]["n_heads"]), + num_key_value_heads=( + olmo_core_config["model"]["block"]["attention"].get("n_kv_heads") or n_heads + ), + hidden_act="silu", + max_position_embeddings=max_sequence_length, + rope_theta=olmo_core_config["model"]["block"]["attention"]["rope"]["theta"], + attention_bias=olmo_core_config["model"]["block"]["attention"].get("bias") or False, + pad_token_id=tokenizer.vocab.get(tokenizer.pad_token, None), + bos_token_id=tokenizer.vocab.get(tokenizer.bos_token, None), + eos_token_id=tokenizer.vocab.get(tokenizer.eos_token, None), + rms_norm_eps=olmo_core_config["model"]["block"]["layer_norm"]["eps"], + tie_word_embeddings=False, + ) + + with init_empty_weights(): + log.info("Initializing HF model with empty weights...") + model = AutoModelForCausalLM.from_config(huggingface_config) + + log.info("Loading state dict into HF model...") + model.load_state_dict(hf_state_dict, assign=True) + + # Save the model in HF format + log.info(f"Saving HF model checkpoint to {output_path}...") + model.save_pretrained(output_path) + log.info(f"Successfully saved HF model to '{output_path}'") + + # Save the tokenizer in HF format + tokenizer.save_pretrained(output_path) + log.info(f"Successfully saved HF tokenizer to '{output_path}'") + + +def load_config(checkpoint_input_dir: Path) -> dict: + assert ( + checkpoint_input_dir / "config.json" + ).exists(), f"Config file not found at {checkpoint_input_dir}" + + with open(checkpoint_input_dir / "config.json", "r", encoding="utf-8") as f: + config_dict = json.load(f) + + return config_dict + + +def parse_args(): + parser = ArgumentParser(description=__doc__) + parser.add_argument("-i", "--checkpoint-input-dir", type=Path, required=True) + parser.add_argument("-u", "--unsharded-output-dir", type=Path, default=None) + parser.add_argument("-o", "--huggingface-output-dir", type=Path, required=True) + parser.add_argument("-t", "--tokenizer-name-or-path", type=str, default=None) + parser.add_argument("-s", "--max-sequence-length", type=int, default=-1) + return parser.parse_args() + + +def main(): + args = 
+    experiment_config = load_config(args.checkpoint_input_dir)
+
+    if args.tokenizer_name_or_path is None:
+        args.tokenizer_name_or_path = (
+            experiment_config.get("dataset", {}).get("tokenizer", {}).get("identifier", None)
+        )
+
+    if args.tokenizer_name_or_path is None:
+        raise ValueError(
+            "Tokenizer identifier not found in the config; please provide it manually"
+        )
+
+    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name_or_path)
+
+    with TemporaryDirectory() as _unsharded_dir:
+        if args.unsharded_output_dir:
+            log.info(f"Using provided unsharded output directory: {args.unsharded_output_dir}")
+            _unsharded_dir = args.unsharded_output_dir
+
+        shards_dir = args.checkpoint_input_dir / "model_and_optim"
+        if shards_dir.exists() and shards_dir.is_dir():
+            log.info(f"Unsharding checkpoint from {shards_dir} to {_unsharded_dir}")
+            (unsharded_dir := Path(_unsharded_dir)).mkdir(parents=True, exist_ok=True)
+            unshard_checkpoint(dir=shards_dir, target_dir=unsharded_dir, optim=False)
+
+            log.info("Copying config.json to unsharded directory")
+            shutil.copy(args.checkpoint_input_dir / "config.json", unsharded_dir / "config.json")
+        else:
+            log.info("No sharded checkpoint found, using input directory as unsharded")
+            unsharded_dir = args.checkpoint_input_dir
+
+        convert_to_hf_checkpoint(
+            olmo_checkpoint_path=unsharded_dir / "model.pt",
+            output_path=args.huggingface_output_dir,
+            olmo_core_config=experiment_config,
+            max_sequence_length=args.max_sequence_length,
+            tokenizer=tokenizer,  # type: ignore
+        )
+
+
+if __name__ == "__main__":
+    prepare_cli_environment()
+    main()
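
For reference, a conversion run with the new script would look roughly like the following; the input and output paths are placeholders, and `-t` can be omitted when the tokenizer identifier is recorded in the checkpoint's config.json:

    python src/examples/huggingface/convert_checkpoint_to_hf.py \
        -i /path/to/olmo-core-checkpoint \
        -o /path/to/hf-output \
        -t <tokenizer-name-or-path>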