diff --git a/pyproject.toml b/pyproject.toml
index 9d667260..9e315d6e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -46,6 +46,8 @@ Source = "https://github.com/opendatahub-io/vllm-tgis-adapter"
 
 [project.scripts]
 grpc_healthcheck = "vllm_tgis_adapter.healthcheck:cli"
+model-util = "vllm_tgis_adapter.tgis_utils.scripts:cli"
+text-generation-server = "vllm_tgis_adapter.tgis_utils.scripts:cli"
 
 [project.optional-dependencies]
 tests = [
@@ -82,7 +84,10 @@ vllm_tgis_adapter = [
 ]
 
 [tool.pytest.ini_options]
-addopts = "-ra"
+addopts = "-ra -k \"not hf_data\""
+markers = [
+    "hf_data: marks tests that download data from HF hub (deselect with '-m \"not hf_data\"')"
+]
 
 [tool.coverage.run]
 branch = true
diff --git a/src/vllm_tgis_adapter/tgis_utils/hub.py b/src/vllm_tgis_adapter/tgis_utils/hub.py
new file mode 100644
index 00000000..89096cd6
--- /dev/null
+++ b/src/vllm_tgis_adapter/tgis_utils/hub.py
@@ -0,0 +1,221 @@
+from __future__ import annotations
+
+import concurrent  # concurrent.futures submodule is bound by the ThreadPoolExecutor import below
+import datetime
+import json
+import os
+from concurrent.futures import ThreadPoolExecutor
+from functools import partial
+from pathlib import Path
+
+import torch
+from huggingface_hub import HfApi, hf_hub_download, try_to_load_from_cache
+from huggingface_hub.utils import LocalEntryNotFoundError
+from safetensors.torch import _remove_duplicate_names, load_file, save_file
+from tqdm import tqdm
+
+from vllm_tgis_adapter.logging import init_logger
+
+logger = init_logger(__name__)
+
+
+def weight_hub_files(
+    model_name: str,
+    extension: str | list[str] = ".safetensors",
+    revision: str | None = None,
+    auth_token: str | None = None,
+) -> list[str]:
+    """Get the top-level hub filenames matching the given extension(s)."""
+    exts = [extension] if isinstance(extension, str) else extension
+    api = HfApi()
+    info = api.model_info(model_name, revision=revision, token=auth_token)
+    filenames = [
+        s.rfilename
+        for s in info.siblings
+        if any(
+            s.rfilename.endswith(ext)
+            and len(s.rfilename.split("/")) == 1
+            and "arguments" not in s.rfilename
+            and "args" not in s.rfilename
+            and "training" not in s.rfilename
+            for ext in exts
+        )
+    ]
+    return filenames
+
+
+def weight_files(
+    model_name: str, extension: str = ".safetensors", revision: str | None = None
+) -> list[str]:
+    """Get the local (cached) paths of the hub files with this extension."""
+    filenames = weight_hub_files(model_name, extension, revision=revision)
+    files = []
+    for filename in filenames:
+        cache_file = try_to_load_from_cache(
+            model_name, filename=filename, revision=revision
+        )
+        if cache_file is None:
+            raise LocalEntryNotFoundError(
+                f"File {filename} of model {model_name} not found in "
+                f"{os.getenv('HUGGINGFACE_HUB_CACHE', 'the local cache')}. "
+                f"Please run `vllm download-weights "
+                f"{model_name}` first."
+            )
+        files.append(cache_file)
+
+    return files
+
+
+def download_weights(
+    model_name: str,
+    extension: str | list[str] = ".safetensors",
+    revision: str | None = None,
+    auth_token: str | None = None,
+) -> list[str]:
+    """Download the hub files with the given extension(s), in parallel."""
+    filenames = weight_hub_files(
+        model_name, extension, revision=revision, auth_token=auth_token
+    )
+
+    download_function = partial(
+        hf_hub_download,
+        repo_id=model_name,
+        local_files_only=False,
+        revision=revision,
+        token=auth_token,
+    )
+
+    logger.info("Downloading %s files for model %s", len(filenames), model_name)
+    with ThreadPoolExecutor(max_workers=min(16, os.cpu_count() or 1)) as executor:
+        futures = [
+            executor.submit(download_function, filename=filename) for filename in filenames
+        ]
+        files = [
+            future.result()
+            for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures))
+        ]
+
+    return files
+
+
+def get_model_path(model_name: str, revision: str | None = None) -> str:
+    """Get path to model dir in local huggingface hub (model) cache."""
+    config_file = "config.json"
+    config_path = try_to_load_from_cache(
+        model_name,
+        config_file,
+        cache_dir=os.getenv(
+            "TRANSFORMERS_CACHE"
+        ),  # will fall back to HUGGINGFACE_HUB_CACHE
+        revision=revision,
+    )
+    if config_path is not None:
+        return config_path.removesuffix(f"/{config_file}")
+    if Path(f"{model_name}/{config_file}").is_file():
+        return model_name  # Just treat the model name as an explicit model path
+
+    raise ValueError(f"Weights not found in local cache for model {model_name}")
+
+
+def local_weight_files(model_path: str, extension: str = ".safetensors") -> list[Path]:
+    """Get the local safetensors filenames."""
+    ext = "" if extension is None else extension
+    return list(Path(f"{model_path}").glob(f"*{ext}"))
+
+
+def local_index_files(model_path: str, extension: str = ".safetensors") -> list[Path]:
+    """Get the local .index.json filename."""
+    ext = "" if extension is None else extension
+    return list(Path(f"{model_path}").glob(f"*{ext}.index.json"))
+
+
+def convert_file(pt_file: Path, sf_file: Path, discard_names: list[str]) -> None:
+    """Convert a pytorch file to a safetensors file.
+
+    This will remove duplicate tensors from the file. Unfortunately, this might not
+    respect *transformers* convention forcing us to check for potentially different
+    keys during load when looking for specific tensors (making tensor sharing explicit).
+    """
+    loaded = torch.load(pt_file, map_location="cpu")
+    if "state_dict" in loaded:
+        loaded = loaded["state_dict"]
+    to_removes = _remove_duplicate_names(loaded, discard_names=discard_names)
+
+    metadata = {"format": "pt"}
+    for kept_name, to_remove_group in to_removes.items():
+        for to_remove in to_remove_group:
+            if to_remove not in metadata:
+                metadata[to_remove] = kept_name  # record alias of the shared tensor
+            del loaded[to_remove]
+    # Force tensors to be contiguous
+    loaded = {k: v.contiguous() for k, v in loaded.items()}
+
+    sf_file.parent.mkdir(parents=True, exist_ok=True)
+    save_file(loaded, sf_file, metadata=metadata)
+    reloaded = load_file(sf_file)
+    for k in loaded:
+        pt_tensor = loaded[k]
+        sf_tensor = reloaded[k]
+        if not torch.equal(pt_tensor, sf_tensor):
+            raise RuntimeError(f"The output tensors do not match for key {k}")
+
+
+def convert_index_file(
+    source_file: Path, dest_file: Path, pt_files: list[Path], sf_files: list[Path]
+) -> None:
+    weight_file_map = {s.name: d.name for s, d in zip(pt_files, sf_files)}
+
+    logger.info("Converting pytorch .bin.index.json files to .safetensors.index.json")
+    with open(source_file) as f:
+        index = json.load(f)
+
+    index["weight_map"] = {
+        k: weight_file_map[v] for k, v in index["weight_map"].items()
+    }
+
+    with open(dest_file, "w") as f:
+        json.dump(index, f)
+
+
+def convert_files(
+    pt_files: list[Path], sf_files: list[Path], discard_names: list[str] | None = None
+) -> None:
+    assert len(pt_files) == len(sf_files)  # paired 1:1 by the caller
+
+    # Filter non-inference files
+    pairs = [
+        p
+        for p in zip(pt_files, sf_files)
+        if not any(
+            s in p[0].name
+            for s in [
+                "arguments",
+                "args",
+                "training",
+                "optimizer",
+                "scheduler",
+                "index",
+            ]
+        )
+    ]
+
+    n = len(pairs)
+
+    if n == 0:
+        logger.warning("No pytorch .bin weight files found to convert")
+        return
+
+    logger.info("Converting %d pytorch .bin files to .safetensors...", n)
+
+    for i, (pt_file, sf_file) in enumerate(pairs):
+        file_count = i + 1
+        logger.info('Converting: [%d/%d] "%s"', file_count, n, pt_file.name)
+        start = datetime.datetime.now(tz=datetime.UTC)
+        convert_file(pt_file, sf_file, discard_names)
+        elapsed = datetime.datetime.now(tz=datetime.UTC) - start
+        logger.info(
+            'Converted: [%d/%d] "%s" -- Took: %d seconds',
+            file_count, n,
+            sf_file.name,
+            elapsed.total_seconds(),
+        )
diff --git a/src/vllm_tgis_adapter/tgis_utils/scripts.py b/src/vllm_tgis_adapter/tgis_utils/scripts.py
new file mode 100644
index 00000000..0bba683d
--- /dev/null
+++ b/src/vllm_tgis_adapter/tgis_utils/scripts.py
@@ -0,0 +1,235 @@
+from __future__ import annotations
+
+from pathlib import Path
+from typing import TYPE_CHECKING
+
+from vllm.utils import FlexibleArgumentParser
+
+from vllm_tgis_adapter.logging import init_logger
+from vllm_tgis_adapter.tgis_utils import hub
+
+logger = init_logger(__name__)
+
+if TYPE_CHECKING:
+    import argparse
+
+
+def tgis_cli(args: argparse.Namespace) -> None:
+    """Dispatch the parsed sub-command to its implementation."""
+    if args.command == "download-weights":
+        download_weights(
+            args.model_name,
+            args.revision,
+            args.token,
+            args.extension,
+            args.auto_convert,
+        )
+    elif args.command == "convert-to-safetensors":
+        convert_to_safetensors(args.model_name, args.revision)
+    elif args.command == "convert-to-fast-tokenizer":
+        convert_to_fast_tokenizer(args.model_name, args.revision, args.output_path)
+
+
+def download_weights(
+    model_name: str,
+    revision: str | None = None,
+    token: str | None = None,
+    extension: str = ".safetensors",
+    auto_convert: bool | None = None,
+) -> None:
+    """Download weights plus metadata files, optionally converting to safetensors."""
+    if auto_convert is None:
+        auto_convert = True
+
+    logger.info("Downloading files with extension(s): %s", extension)
+    meta_exts = [".json", ".py", ".model", ".md"]
+
+    extensions = extension.split(",")
+
+    if len(extensions) == 1 and extensions[0] not in meta_exts:
+        extensions.extend(meta_exts)
+
+    files = hub.download_weights(
+        model_name, extensions, revision=revision, auth_token=token
+    )
+
+    if auto_convert and ".safetensors" in extensions:
+        if not hub.local_weight_files(
+            hub.get_model_path(model_name, revision), ".safetensors"
+        ):
+            if ".bin" not in extensions:
+                logger.info(
+                    ".safetensors weights not found, "
+                    "downloading pytorch weights to convert..."
+                )
+                hub.download_weights(
+                    model_name, ".bin", revision=revision, auth_token=token
+                )
+
+            logger.info(
+                ".safetensors weights not found, "
+                "converting from pytorch weights..."
+            )
+            convert_to_safetensors(model_name, revision)
+        elif not any(f.endswith(".safetensors") for f in files):
+            logger.info(
+                ".safetensors weights not found on hub, "
+                "but were found locally. Remove them first to re-convert"
+            )
+    if auto_convert:
+        convert_to_fast_tokenizer(model_name, revision)
+
+
+def convert_to_safetensors(
+    model_name: str,
+    revision: str | None = None,
+) -> None:
+    """Convert locally cached pytorch .bin weights into .safetensors files."""
+    model_path = hub.get_model_path(model_name, revision)
+    local_pt_files = hub.local_weight_files(model_path, ".bin")
+    local_pt_index_files = hub.local_index_files(model_path, ".bin")
+    if len(local_pt_index_files) > 1:
+        logger.info(
+            "Found more than one .bin.index.json file: %s", local_pt_index_files
+        )
+        return
+    if not local_pt_files:
+        logger.info("No pytorch .bin files found to convert")
+        return
+
+    local_pt_files = [Path(f) for f in local_pt_files]
+    local_pt_index_file = local_pt_index_files[0] if local_pt_index_files else None
+
+    # Safetensors final filenames
+    local_st_files = [
+        p.parent / f"{p.stem.removeprefix('pytorch_')}.safetensors"
+        for p in local_pt_files
+    ]
+
+    if any(p.exists() for p in local_st_files):
+        logger.info(
+            "Existing .safetensors weights found, remove them first to reconvert"
+        )
+        return
+
+    try:
+        import transformers
+
+        config = transformers.AutoConfig.from_pretrained(
+            model_name,
+            revision=revision,
+        )
+        architecture = config.architectures[0]
+
+        class_ = getattr(transformers, architecture)
+
+        # Name for this variable depends on transformers version
+        discard_names = getattr(class_, "_tied_weights_keys", [])
+        discard_names.extend(getattr(class_, "_keys_to_ignore_on_load_missing", []))
+
+    except Exception:  # best-effort probe; fall back to keeping all tensors
+        discard_names = []
+
+    if local_pt_index_file:
+        local_pt_index_file = Path(local_pt_index_file)
+        st_prefix = local_pt_index_file.stem.removeprefix("pytorch_").removesuffix(
+            ".bin.index"
+        )
+        local_st_index_file = (
+            local_pt_index_file.parent / f"{st_prefix}.safetensors.index.json"
+        )
+
+        if local_st_index_file.exists():
+            logger.info(
+                "Existing .safetensors.index.json file found, "
+                "remove it first to reconvert"
+            )
+            return
+
+        hub.convert_index_file(
+            local_pt_index_file, local_st_index_file, local_pt_files, local_st_files
+        )
+
+    # Convert pytorch weights to safetensors
+    hub.convert_files(local_pt_files, local_st_files, discard_names)
+
+
+def convert_to_fast_tokenizer(
+    model_name: str,
+    revision: str | None = None,
+    output_path: str | None = None,
+) -> None:
+    """Generate and save a fast tokenizer (tokenizer.json) if one is missing."""
+    model_path = hub.get_model_path(model_name, revision)
+
+    if (Path(model_path) / "tokenizer.json").exists():
+        logger.info("Model %s already has a fast tokenizer", model_name)
+        return
+
+    if output_path is not None and not Path(output_path).is_dir():
+        logger.info("Output path %s must exist and be a directory", output_path)
+        return
+    output_path = output_path or model_path
+
+    import transformers
+
+    tokenizer = transformers.AutoTokenizer.from_pretrained(
+        model_name, revision=revision
+    )
+    tokenizer.save_pretrained(output_path)
+
+    logger.info("Saved tokenizer to %s", output_path)
+
+
+def cli() -> None:
+    """Entry point for the model-util / text-generation-server console scripts."""
+    parser = FlexibleArgumentParser(description="vLLM CLI")
+    subparsers = parser.add_subparsers(required=True)
+
+    download_weights_parser = subparsers.add_parser(
+        "download-weights",
+        help=("Download the weights of a given model"),
+        usage="model-util download-weights [options]",
+    )
+    download_weights_parser.add_argument("model_name")
+    download_weights_parser.add_argument("--revision")
+    download_weights_parser.add_argument("--token")
+    download_weights_parser.add_argument("--extension", default=".safetensors")
+    download_weights_parser.add_argument("--auto_convert", default=True, type=lambda v: str(v).lower() not in ("false", "0", "no", "off"))
+    download_weights_parser.set_defaults(
+        dispatch_function=tgis_cli, command="download-weights"
+    )
+
+    convert_to_safetensors_parser = subparsers.add_parser(
+        "convert-to-safetensors",
+        help=("Convert model weights to safetensors"),
+        usage="model-util convert-to-safetensors [options]",
+    )
+    convert_to_safetensors_parser.add_argument("model_name")
+    convert_to_safetensors_parser.add_argument("--revision")
+    convert_to_safetensors_parser.set_defaults(
+        dispatch_function=tgis_cli, command="convert-to-safetensors"
+    )
+
+    convert_to_fast_tokenizer_parser = subparsers.add_parser(
+        "convert-to-fast-tokenizer",
+        help=("Convert to fast tokenizer"),
+        usage="model-util convert-to-fast-tokenizer [options]",
+    )
+    convert_to_fast_tokenizer_parser.add_argument("model_name")
+    convert_to_fast_tokenizer_parser.add_argument("--revision")
+    convert_to_fast_tokenizer_parser.add_argument("--output_path")
+    convert_to_fast_tokenizer_parser.set_defaults(
+        dispatch_function=tgis_cli, command="convert-to-fast-tokenizer"
+    )
+
+    args = parser.parse_args()
+    # One of the sub commands should be executed.
+    if hasattr(args, "dispatch_function"):  # set by each subparser via set_defaults
+        args.dispatch_function(args)
+    else:
+        parser.print_help()  # defensive: subparsers are required, so normally unreachable
+
+
+if __name__ == "__main__":
+    cli()
diff --git a/tests/test_hub.py b/tests/test_hub.py
new file mode 100644
index 00000000..b64a2218
--- /dev/null
+++ b/tests/test_hub.py
@@ -0,0 +1,54 @@
+from pathlib import Path
+
+import pytest
+from huggingface_hub.utils import LocalEntryNotFoundError
+
+from vllm_tgis_adapter.tgis_utils.hub import (
+    convert_files,
+    download_weights,
+    weight_files,
+    weight_hub_files,
+)
+
+pytestmark = pytest.mark.hf_data  # every test below downloads data from the HF hub
+
+
+def test_convert_files():
+    model_id = "facebook/opt-125m"
+    local_pt_files = download_weights(model_id, extension=".bin")
+    local_pt_files = [Path(p) for p in local_pt_files]
+    local_st_files = [
+        p.parent / f"{p.stem.removeprefix('pytorch_')}.safetensors"
+        for p in local_pt_files
+    ]
+    convert_files(local_pt_files, local_st_files, discard_names=[])
+
+    found_st_files = weight_files(model_id)
+
+    assert all(str(p) in found_st_files for p in local_st_files)
+
+
+def test_weight_hub_files():
+    filenames = weight_hub_files("facebook/opt-125m")
+    assert filenames == ["model.safetensors"]
+
+
+def test_weight_hub_files_llm():
+    filenames = weight_hub_files("bigscience/bloom")
+    assert filenames == [f"model_{i:05d}-of-00072.safetensors" for i in range(1, 73)]
+
+
+def test_weight_hub_files_empty():
+    filenames = weight_hub_files("bigscience/bloom", ".errors")
+    assert filenames == []
+
+
+def test_download_weights():
+    files = download_weights("facebook/opt-125m")
+    local_files = weight_files("facebook/opt-125m")
+    assert files == local_files
+
+
+def test_weight_files_error():
+    with pytest.raises(LocalEntryNotFoundError):
+        weight_files("bert-base-uncased")  # never downloaded locally -> cache miss