diff --git a/CHANGELOG.md b/CHANGELOG.md
index b73eeae96..9f0472818 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ## Unreleased
 
+- Added support for safetensors in `hf_olmo` conversion script.
+
 ## [v0.5.1](https://github.com/allenai/OLMo/releases/tag/v0.5.1) - 2024-10-17
 
 ### Added
@@ -45,7 +47,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Swapped in correct flan data mix.
 - Fix bug where the attention norm, when applied before the attention block, was modifying the residual stream.
 - Fixed `OLMo.from_checkpoint()` so that it correctly loads `olmo_core` and `torch_new` style checkpoints.
-- Fixed `preserve_rng_state` being incorrectly set to False when doing gradient checkpointing with dropout 
+- Fixed `preserve_rng_state` being incorrectly set to False when doing gradient checkpointing with dropout
 
 ## [v0.4.0](https://github.com/allenai/OLMo/releases/tag/v0.4.0) - 2024-07-11
 
diff --git a/hf_olmo/convert_olmo_to_hf.py b/hf_olmo/convert_olmo_to_hf.py
index 2e0a9e074..e00f4cbe0 100644
--- a/hf_olmo/convert_olmo_to_hf.py
+++ b/hf_olmo/convert_olmo_to_hf.py
@@ -1,10 +1,10 @@
 import argparse
-import logging
 import os
 import re
 import shutil
 import tempfile
 from hashlib import md5
+from pathlib import Path
 from typing import Iterable, Optional
 from urllib.parse import urlparse
 
@@ -16,10 +16,11 @@
 from hf_olmo.modeling_olmo import OLMoForCausalLM
 from hf_olmo.tokenization_olmo_fast import OLMoTokenizerFast
 from olmo import ModelConfig, Tokenizer, TrainConfig
+from olmo.aliases import PathOrStr
 from olmo.checkpoint import build_sharded_checkpointer
-from olmo.util import _get_s3_client
+from olmo.safetensors_util import safetensors_file_to_state_dict
+from olmo.util import _get_gcs_client, _get_s3_client
 
-logger = logging.getLogger(__name__)
 
 HF_FILENAMES = {
     "config.json",
@@ -30,6 +31,12 @@
 }
 
 
+def walk_local_path(path: PathOrStr, top_down=True, on_error=None, follow_symlinks=False):
+    """Necessary because Path.walk() was only added in python 3.12"""
+    for root, dirs, files in os.walk(str(path), topdown=top_down, onerror=on_error, followlinks=follow_symlinks):
+        yield Path(root), dirs, files
+
+
 def longest_common_prefix(strs: Iterable[str]) -> str:
     """
     Finds the longest common prefix among a list of strings.
@@ -48,29 +55,37 @@ def longest_common_prefix(strs: Iterable[str]) -> str:
     return shortest_str
 
 
-def write_config(checkpoint_dir: str):
+def write_config(checkpoint_dir: str, destination_dir: str):
     # save config as HF config
-    logger.info(f"Loading checkpoint from {checkpoint_dir}")
+    print(f"Loading checkpoint from {checkpoint_dir}")
+
+    if os.path.exists(os.path.join(destination_dir, "config.yaml")):
+        config_path = os.path.join(destination_dir, "config.yaml")
+    else:
+        config_path = os.path.join(checkpoint_dir, "config.yaml")
 
-    config_path = os.path.join(checkpoint_dir, "config.yaml")
     model_config = ModelConfig.load(config_path, key="model")
     config_kwargs = model_config.asdict()
     config_kwargs["use_cache"] = True
     config = OLMoConfig(**config_kwargs)
 
-    logger.info(f"Saving HF-compatible config to {os.path.join(checkpoint_dir, 'config.json')}")
-    config.save_pretrained(checkpoint_dir)
+    print(f"Saving HF-compatible config to {os.path.join(destination_dir, 'config.json')}")
+    config.save_pretrained(destination_dir)
 
 
-def write_model(checkpoint_dir: str, ignore_olmo_compatibility: bool = False):
+def write_model(checkpoint_dir: str, destination_dir: str, ignore_olmo_compatibility: bool = False):
     # For device_map = "auto", etc. the models are loaded in a way that start_prefix is not computed correctly.
     # So, we explicitly store the model with the expected prefix.
-    old_model_path = os.path.join(checkpoint_dir, "model.pt")
-    new_model_path = os.path.join(checkpoint_dir, "pytorch_model.bin")
+    if os.path.exists(old_model_path := os.path.join(checkpoint_dir, "model.pt")):
+        state_dict = torch.load(old_model_path, map_location="cpu")
+    elif os.path.exists(old_model_path := os.path.join(checkpoint_dir, "model.safetensors")):
+        state_dict = safetensors_file_to_state_dict(old_model_path, map_location="cpu")
+    else:
+        raise ValueError(f"No model found in {checkpoint_dir}")
 
-    state_dict = torch.load(old_model_path, map_location="cpu")
+    new_model_path = os.path.join(destination_dir, "pytorch_model.bin")
 
     # this takes care of the case where the model was saved with a different prefix,
     # typically due to unsharding.
@@ -85,7 +100,7 @@ def write_model(checkpoint_dir: str, ignore_olmo_compatibility: bool = False):
     os.remove(old_model_path)
 
 
-def write_tokenizer(checkpoint_dir: str):
+def write_tokenizer(checkpoint_dir: str, destination_dir: str):
     tokenizer_raw = Tokenizer.from_checkpoint(checkpoint_dir)
     tokenizer = OLMoTokenizerFast(
         tokenizer_object=tokenizer_raw.base_tokenizer,
@@ -96,33 +111,37 @@ def write_tokenizer(checkpoint_dir: str):
     tokenizer.model_input_names = ["input_ids", "attention_mask"]
     tokenizer.pad_token_id = tokenizer_raw.pad_token_id
     tokenizer.eos_token_id = tokenizer_raw.eos_token_id
-
-    tokenizer.save_pretrained(checkpoint_dir)
+    tokenizer.save_pretrained(destination_dir)
 
 
-def convert_checkpoint(checkpoint_dir: str, ignore_olmo_compatibility: bool = False):
+def convert_checkpoint(checkpoint_dir: str, destination_dir: str, ignore_olmo_compatibility: bool = False):
     print("Converting checkpoint to HF format...")
-    write_config(checkpoint_dir)
+    write_config(checkpoint_dir=checkpoint_dir, destination_dir=destination_dir)
 
     print("Saving model to checkpoint...")
-    write_model(checkpoint_dir, ignore_olmo_compatibility=ignore_olmo_compatibility)
+    write_model(
+        checkpoint_dir=checkpoint_dir,
+        destination_dir=destination_dir,
+        ignore_olmo_compatibility=ignore_olmo_compatibility
+    )
 
     print("Saving tokenizer to checkpoint...")
-    write_tokenizer(checkpoint_dir)
+    write_tokenizer(checkpoint_dir=checkpoint_dir, destination_dir=destination_dir)
 
     # Cannot remove it before writing the tokenizer
     if ignore_olmo_compatibility:
-        os.remove(os.path.join(checkpoint_dir, "config.yaml"))
+        os.remove(os.path.join(destination_dir, "config.yaml"))
 
 
-def fix_tokenizer(checkpoint_dir: str, tokenizer_name_or_path: Optional[str] = None):
-    path = os.path.join(checkpoint_dir, "config.yaml")
-    conf = om.load(path)
+def fix_tokenizer(checkpoint_dir: str, destination_dir: str, tokenizer_name_or_path: Optional[str] = None):
+    Path(destination_dir).mkdir(parents=True, exist_ok=True)
 
-    print("Saving tokenizer to checkpoint...")
+    source_path = os.path.join(checkpoint_dir, "config.yaml")
+    dest_path = os.path.join(destination_dir, "config.yaml")
+    conf = om.load(source_path)
 
+    print(f"Saving new tokenizer configuration to {dest_path}")
     tokenizer_name_or_path = str(tokenizer_name_or_path or conf["tokenizer"]["identifier"])  # pyright: ignore
-
     try:
         if os.path.exists(tokenizer_name_or_path):
             Tokenizer.from_file(tokenizer_name_or_path)
@@ -130,7 +149,7 @@ def fix_tokenizer(checkpoint_dir: str, tokenizer_name_or_path: Optional[str] = N
             Tokenizer.from_pretrained(tokenizer_name_or_path)
     except Exception as e:
         # the tokenizer is not valid
-        logger.error(f"Invalid tokenizer: {tokenizer_name_or_path}. Error: {e}")
+        print(f"Invalid tokenizer: {tokenizer_name_or_path}. Error: {e}")
Error: {e}") raise e conf["tokenizer"]["identifier"] = tokenizer_name_or_path # pyright: ignore @@ -140,7 +159,24 @@ def fix_tokenizer(checkpoint_dir: str, tokenizer_name_or_path: Optional[str] = N ): conf["model"]["eos_token_id"] = 50279 # pyright: ignore - om.save(conf, path) + om.save(conf, dest_path) + + +def download_gcs_directory(bucket_name: str, prefix: str, local_dir: str): + path_local = Path(local_dir) + path_prefix = Path(prefix) + + gcs_client = _get_gcs_client() + bucket = gcs_client.bucket(bucket_name) + + path_local.mkdir(parents=True, exist_ok=True) + + files_to_download = list(bucket.list_blobs(prefix=prefix)) + + for elem in tqdm(files_to_download, desc="Downloading files from GCS"): + local_destination = path_local / Path(elem.name).relative_to(path_prefix) + local_destination.parent.mkdir(parents=True, exist_ok=True) + elem.download_to_filename(local_destination) def download_s3_directory(bucket_name: str, prefix: str, local_dir: str, ignore: str | None = None): @@ -162,7 +198,7 @@ def download_s3_directory(bucket_name: str, prefix: str, local_dir: str, ignore: files_to_download.append(obj["Key"]) # Initialize the progress bar - for s3_key in tqdm(files_to_download, desc="Downloading files"): + for s3_key in tqdm(files_to_download, desc="Downloading files from S3"): # Construct the full local path local_file_path = os.path.join(local_dir, os.path.relpath(s3_key, prefix)) local_file_dir = os.path.dirname(local_file_path) @@ -178,7 +214,7 @@ def download_s3_directory(bucket_name: str, prefix: str, local_dir: str, ignore: def make_local_checkpoint(checkpoint_dir: str) -> str: parsed_dir = urlparse(checkpoint_dir) - assert parsed_dir.scheme in ["s3", ""], "Only s3 and local paths are supported." + assert parsed_dir.scheme in ["s3", "gs", "", "file"], "Only s3, gcs, and local paths are supported." if os.path.exists(checkpoint_dir): return checkpoint_dir @@ -189,51 +225,105 @@ def make_local_checkpoint(checkpoint_dir: str) -> str: try: os.makedirs(temp_dir, exist_ok=True) print(f"Downloading checkpoint to {temp_dir}...") - download_s3_directory( - bucket_name=parsed_dir.netloc, - prefix=parsed_dir.path.lstrip("/"), - local_dir=temp_dir, - ignore=r"/(optim|train)/", - ) + + if parsed_dir.scheme == "gs": + download_gcs_directory( + bucket_name=parsed_dir.netloc, + prefix=parsed_dir.path.lstrip("/"), + local_dir=temp_dir, + ) + elif parsed_dir.scheme == "s3": + download_s3_directory( + bucket_name=parsed_dir.netloc, + prefix=parsed_dir.path.lstrip("/"), + local_dir=temp_dir, + ignore=r"/(optim|train)/", + ) + else: + raise ValueError(f"Unsupported: {checkpoint_dir}. Only s3://, gs://, and local are supported.") except Exception as e: - logger.error(f"Error downloading checkpoint: {e}") + print(f"Error downloading checkpoint: {e}") shutil.rmtree(temp_dir) raise e return temp_dir +def upload_s3_directory(local_checkpoint_dir: str, destination_dir: str): + parsed_destination = urlparse(destination_dir) + if parsed_destination.scheme != "s3": + raise ValueError(f"Unsupported destination: {destination_dir}. 
+
+    s3_client = _get_s3_client("s3")
+    s3_bucket_name = parsed_destination.netloc
+    s3_prefix = Path(parsed_destination.path.lstrip("/"))
+    local_checkpoint_path = Path(local_checkpoint_dir)
+    local_paths = [
+        Path(path / fn) for path, _, filenames in walk_local_path(local_checkpoint_path) for fn in filenames
+    ]
+
+    for local_path in tqdm(local_paths, desc="Uploading files to S3"):
+        destination = s3_prefix / local_path.relative_to(local_checkpoint_path)
+        s3_client.upload_file(str(local_path), s3_bucket_name, str(destination))
+
+
+def upload_gcs_directory(local_checkpoint_dir: str, destination_dir: str):
+    parsed_destination = urlparse(destination_dir)
+    if parsed_destination.scheme != "gs":
+        raise ValueError(f"Unsupported destination: {destination_dir}. Only gs paths are supported.")
+
+    gcs_client = _get_gcs_client()
+    bucket_name = parsed_destination.netloc
+    prefix = Path(parsed_destination.path.lstrip("/"))
+    local_checkpoint_path = Path(local_checkpoint_dir)
+    local_paths = [
+        Path(path / fn) for path, _, filenames in walk_local_path(local_checkpoint_path) for fn in filenames
+    ]
+
+    bucket = gcs_client.bucket(bucket_name)
+
+    for local_path in tqdm(local_paths, desc="Uploading files to GCS"):
+        destination = prefix / local_path.relative_to(local_checkpoint_path)
+        blob = bucket.blob(str(destination))
+        blob.upload_from_filename(local_path)
+
+
 def upload_local_checkpoint(local_checkpoint_dir: str, destination_dir: str):
     if destination_dir == local_checkpoint_dir:
         return
-    elif (parsed_url := urlparse(destination_dir)).scheme == "s3":
-        s3_bucket_name = parsed_url.netloc
-        s3_prefix = parsed_url.path[1:]
-
-        local_paths = [
-            os.path.join(root, post_fn)
-            for root, _, files in os.walk(local_checkpoint_dir)
-            for post_fn in files
-            if os.path.basename(post_fn) in HF_FILENAMES
-        ]
-        dest_paths = [
-            os.path.join(s3_prefix, os.path.relpath(local_path, local_checkpoint_dir))
-            for local_path in local_paths
-        ]
-
-        s3_client = _get_s3_client("s3")
-        for local_path, dest_path in tqdm(
-            zip(local_paths, dest_paths), desc="Uploading files", total=len(local_paths)
-        ):
-            s3_client.upload_file(local_path, s3_bucket_name, dest_path)
-    elif parsed_url.scheme == "":
-        shutil.copytree(local_checkpoint_dir, destination_dir)
-    else:
-        raise ValueError(f"Unsupported destination: {destination_dir}. Only s3 and local paths are supported.")
+    if (parsed_url := urlparse(destination_dir)).scheme == "s3":
+        return upload_s3_directory(local_checkpoint_dir, destination_dir)
+
+    elif parsed_url.scheme == "gs":
+        return upload_gcs_directory(local_checkpoint_dir, destination_dir)
 
-def maybe_unshard(checkpoint_dir: str):
+    elif parsed_url.scheme == "":
+        # local destination: copy the converted files directly
+        shutil.copytree(local_checkpoint_dir, destination_dir, dirs_exist_ok=True)
+        return
+
+    raise ValueError(f"Unsupported protocol: {destination_dir}. Only s3://, gs://, and local are supported.")
+
+
+def maybe_unshard(checkpoint_dir: str, destination_dir: str):
     if os.path.exists(os.path.join(checkpoint_dir, "model.pt")):
+        # copy the model.pt to the destination directory
+        if checkpoint_dir != destination_dir:
+            print("Copying model.pt to destination directory...")
+            shutil.copy(os.path.join(checkpoint_dir, "model.pt"), os.path.join(destination_dir, "model.pt"))
+
+        print("model.pt found; skipping unsharding.")
+        return
+
+    if os.path.exists(os.path.join(checkpoint_dir, "model.safetensors")):
+        # copy the model.safetensors to the destination directory
+        if checkpoint_dir != destination_dir:
+            print("Copying model.safetensors to destination directory...")
+            shutil.copy(
+                os.path.join(checkpoint_dir, "model.safetensors"),
+                os.path.join(destination_dir, "model.safetensors")
+            )
+        print("model.safetensors found; skipping unsharding.")
         return
 
     print(f"Unsharding {checkpoint_dir}...")
@@ -268,12 +358,6 @@ def main():
         help="Ignore compatibility with the olmo codebase. "
         "This will remove files that are needed specifically for olmo codebase, eg. config.yaml, etc.",
     )
-    parser.add_argument(
-        "--logger-level",
-        default="warning",
-        help="Set the logger level.",
-    )
-
     parser.add_argument(
         "--tokenizer",
         help="Override the tokenizer to use for the checkpoint.",
@@ -285,29 +369,48 @@
     )
     args = parser.parse_args()
 
-    args.destination_dir = args.destination_dir or args.checkpoint_dir
-    logging.basicConfig()
-    logger.setLevel(logging.getLevelName(args.logger_level.upper()))
+    # default to the source checkpoint location when no destination is given
+    args.destination_dir = args.destination_dir or args.checkpoint_dir
+    local_destination_dir = args.destination_dir
 
+    try:
+        local_checkpoint_dir = make_local_checkpoint(args.checkpoint_dir)
 
-    local_checkpoint_dir = make_local_checkpoint(args.checkpoint_dir)
-    args.checkpoint_dir = local_checkpoint_dir
-    maybe_unshard(local_checkpoint_dir)
+        if local_checkpoint_dir != args.checkpoint_dir:
+            # if using a remote checkpoint, save the converted checkpoint locally
+            print("Remote checkpoint; using local directory as destination.")
+            local_destination_dir = local_checkpoint_dir
 
-    fix_tokenizer(checkpoint_dir=local_checkpoint_dir, tokenizer_name_or_path=args.tokenizer)
-    convert_checkpoint(args.checkpoint_dir, args.ignore_olmo_compatibility)
+        Path(local_destination_dir).mkdir(parents=True, exist_ok=True)
+        maybe_unshard(checkpoint_dir=local_checkpoint_dir, destination_dir=local_destination_dir)
 
-    if not args.keep_olmo_artifacts:
-        print("Removing non-HF artifacts...")
-        os.remove(os.path.join(local_checkpoint_dir, "config.yaml"))
-        os.remove(os.path.join(local_checkpoint_dir, "model.pt"))
-        shutil.rmtree(os.path.join(local_checkpoint_dir, "optim"), ignore_errors=True)
-        shutil.rmtree(os.path.join(local_checkpoint_dir, "model"), ignore_errors=True)
-        shutil.rmtree(os.path.join(local_checkpoint_dir, "train"), ignore_errors=True)
+        fix_tokenizer(
+            checkpoint_dir=local_checkpoint_dir,
+            destination_dir=local_destination_dir,
+            tokenizer_name_or_path=args.tokenizer
+        )
 
-    upload_local_checkpoint(local_checkpoint_dir, args.destination_dir)
+        convert_checkpoint(
+            checkpoint_dir=local_checkpoint_dir,
+            destination_dir=local_destination_dir,
+            ignore_olmo_compatibility=args.ignore_olmo_compatibility
+        )
+
"model"), ignore_errors=True) + shutil.rmtree(os.path.join(local_checkpoint_dir, "train"), ignore_errors=True) + + upload_local_checkpoint(local_destination_dir, args.destination_dir) - print(f"Converted checkpoint saved to {args.destination_dir}") + print(f"Converted checkpoint saved to {args.destination_dir}") + except Exception as e: + print(f"Error converting checkpoint: {e}") + if args.checkpoint_dir != local_destination_dir: + print("Removing partially converted checkpoint...") + shutil.rmtree(args.destination_dir) + raise e if __name__ == "__main__": diff --git a/olmo/util.py b/olmo/util.py index aad77eb1c..5ba85a9a6 100644 --- a/olmo/util.py +++ b/olmo/util.py @@ -503,6 +503,13 @@ def _get_s3_endpoint_url(scheme: str) -> Optional[str]: raise NotImplementedError(f"Cannot get endpoint url for scheme {scheme}") +@cache +def _get_gcs_client(): + from google.cloud import storage as gcs + + return gcs.Client() + + @cache def _get_s3_client(scheme: str): session = boto3.Session(profile_name=_get_s3_profile_name(scheme)) @@ -637,7 +644,11 @@ def _http_file_size(scheme: str, host_name: str, path: str) -> int: import requests response = requests.head(f"{scheme}://{host_name}/{path}", allow_redirects=True) - return int(response.headers.get("content-length")) + + if (content_length := response.headers.get("content-length")) is not None: + return int(content_length) + + raise OLMoNetworkError(f"Failed to get {scheme} file size") def _http_get_bytes_range(scheme: str, host_name: str, path: str, bytes_start: int, num_bytes: int) -> bytes: @@ -647,9 +658,10 @@ def _http_get_bytes_range(scheme: str, host_name: str, path: str, bytes_start: i f"{scheme}://{host_name}/{path}", headers={"Range": f"bytes={bytes_start}-{bytes_start+num_bytes-1}"} ) result = response.content - assert ( - len(result) == num_bytes - ), f"expected {num_bytes} bytes, got {len(result)}" # Some web servers silently ignore range requests and send everything + + # Some web servers silently ignore range requests and send everything + assert len(result) == num_bytes, f"expected {num_bytes} bytes, got {len(result)}" + return result diff --git a/scripts/olmo_soup.py b/scripts/olmo_soup.py new file mode 100644 index 000000000..64e037ede --- /dev/null +++ b/scripts/olmo_soup.py @@ -0,0 +1,123 @@ +""" +Soups OLMo checkpoints. 
+
+Example usage:
+
+```bash
+    python scripts/olmo_soup.py -c \
+        /weka/oe-training-default/ai2-llm/checkpoints/OLMo-medium/peteish7-anneal-from-928646-50B-nowup-moremath-dclm07-fw2-se-flan/step11931 \
+        /weka/oe-training-default/ai2-llm/checkpoints/OLMo-medium/peteish7-anneal-from-928646-50B-nowup-moremath-dclm07-fw2-se-flan-seed2/step11931 \
+        /weka/oe-training-default/ai2-llm/checkpoints/OLMo-medium/peteish7-anneal-from-928646-50B-nowup-moremath-dclm07-fw2-se-flan-seed3/step11931 \
+    -o /weka/oe-training-default/ai2-llm/checkpoints/OLMo-medium/peteish7-anneal-from-928646-50B-nowup-moremath-dclm07-fw2-se-flan-soup/step11931
+```
+
+Author: Luca Soldaini (@soldni)
+
+"""  # noqa
+
+
+import argparse
+from enum import Enum
+from pathlib import Path
+
+import torch
+from tqdm import tqdm
+
+from olmo.checkpoint import build_sharded_checkpointer
+from olmo.config import TrainConfig
+from olmo.safetensors_util import safetensors_file_to_state_dict
+
+
+class SoupType(Enum):
+    uniform = "uniform"
+
+
+def load_checkpoint(path: Path) -> dict[str, torch.Tensor]:
+    if path.exists() and path.is_file():
+        return torch.load(path, map_location="cpu", weights_only=True)
+
+    if (path / "model.pt").exists():
+        return torch.load(path / "model.pt", map_location="cpu", weights_only=True)
+
+    if (path / "model.safetensors").exists():
+        return safetensors_file_to_state_dict(path / "model.safetensors")
+
+    if (path / "model").exists() and (config_path := (path / "config.yaml")).exists():
+        train_config = TrainConfig.load(config_path)
+        checkpointer = build_sharded_checkpointer(train_config)
+        model_state, _, _ = checkpointer.unshard_checkpoint(
+            load_path=str(path), load_optimizer_state=False, load_trainer_state=False
+        )
+        return model_state
+
+    raise FileNotFoundError(f"Could not find checkpoint in {path}")
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Soup OLMo checkpoints")
+    parser.add_argument(
+        "-c",
+        "--checkpoints",
+        type=Path,
+        required=True,
+        nargs="+",
+        help="Path to checkpoint(s) to soup",
+    )
+    parser.add_argument(
+        "-s",
+        "--soup-type",
+        type=SoupType,
+        default=SoupType.uniform,
+        help=f"Method for checkpoint souping. Choose from: {', '.join(SoupType.__members__.keys())}",
+    )
+    parser.add_argument(
+        "-o",
+        "--output",
+        type=Path,
+        required=True,
+        help="Path to save the souped checkpoint",
+    )
+    opts = parser.parse_args()
+    return opts
+
+
+def main():
+    args = parse_args()
+
+    checkpoint_average: dict[str, torch.Tensor] = {}
+
+    for path in tqdm(args.checkpoints, desc="Loading checkpoints", position=0):
+        state_dict = load_checkpoint(path)
+
+        if len(checkpoint_average) == 0:
+            # initialize checkpoint_average with zeros
+            checkpoint_average = {k: torch.zeros_like(v) for k, v in state_dict.items()}
+
+        if any(k not in state_dict for k in checkpoint_average.keys()) or any(
+            k not in checkpoint_average for k in state_dict.keys()
+        ):
+            raise ValueError(f"Checkpoint {path} has different keys")
+
+        for k in tqdm(state_dict, desc="Summing checkpoints", position=1):
+            if state_dict[k].shape != checkpoint_average[k].shape:
+                raise ValueError(f"Checkpoint {path} has different shape for key {k}")
+            checkpoint_average[k] += state_dict[k] / len(args.checkpoints)
+
+        # free memory
+        del state_dict
+
+    print(f"Saving averaged checkpoint to {args.output}")
+    # save the averaged checkpoint
+    args.output.mkdir(parents=True, exist_ok=True)
+    torch.save(checkpoint_average, args.output / "model.pt")
+
+    print("Copying config.yaml")
+    # copy the config file
+    if (config_path := args.checkpoints[0] / "config.yaml").exists():
+        with open(config_path, "r") as src_f, open(args.output / "config.yaml", "w") as dst_f:
+            dst_f.write(src_f.read())
+    print("Done!")
+
+
+if __name__ == "__main__":
+    main()
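
Usage note (not part of the diff above): once `convert_olmo_to_hf.py` has written `config.json`, `pytorch_model.bin`, and the tokenizer files into the destination directory, the result can be loaded through the `hf_olmo` wrappers imported at the top of the conversion script. A minimal sketch, assuming a hypothetical local destination directory `./converted/step11931`; remote `s3://` or `gs://` destinations would need to be downloaded locally first, since `from_pretrained` does not read those schemes:

```python
# Minimal loading sketch; the destination path below is illustrative, not taken from the diff.
from hf_olmo.modeling_olmo import OLMoForCausalLM
from hf_olmo.tokenization_olmo_fast import OLMoTokenizerFast

checkpoint_dir = "./converted/step11931"  # hypothetical output of convert_olmo_to_hf.py

model = OLMoForCausalLM.from_pretrained(checkpoint_dir)
tokenizer = OLMoTokenizerFast.from_pretrained(checkpoint_dir)

inputs = tokenizer("Language modeling is", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```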