From 995f56236bc08300ea11fc8cd3d66029ffec8678 Mon Sep 17 00:00:00 2001 From: omer-dayan Date: Fri, 20 Dec 2024 18:46:24 +0200 Subject: [PATCH] [Core] Loading model from S3 using RunAI Model Streamer as optional loader (#10192) Signed-off-by: OmerD --- Dockerfile | 4 +- docs/source/index.rst | 1 + docs/source/serving/runai_model_streamer.rst | 53 +++++++ setup.py | 1 + tests/runai_model_streamer/__init__.py | 0 .../test_runai_model_streamer_loader.py | 31 ++++ .../runai_model_streamer/test_weight_utils.py | 39 +++++ vllm/config.py | 37 +++++ vllm/engine/arg_utils.py | 2 + vllm/model_executor/model_loader/loader.py | 118 +++++++++++++- .../model_loader/weight_utils.py | 24 +++ vllm/transformers_utils/s3_utils.py | 146 ++++++++++++++++++ vllm/transformers_utils/utils.py | 4 + 13 files changed, 457 insertions(+), 3 deletions(-) create mode 100644 docs/source/serving/runai_model_streamer.rst create mode 100644 tests/runai_model_streamer/__init__.py create mode 100644 tests/runai_model_streamer/test_runai_model_streamer_loader.py create mode 100644 tests/runai_model_streamer/test_weight_utils.py create mode 100644 vllm/transformers_utils/s3_utils.py diff --git a/Dockerfile b/Dockerfile index 0944050f7dfca..84350cde59bfb 100644 --- a/Dockerfile +++ b/Dockerfile @@ -240,9 +240,9 @@ FROM vllm-base AS vllm-openai # install additional dependencies for openai api server RUN --mount=type=cache,target=/root/.cache/pip \ if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \ - pip install accelerate hf_transfer 'modelscope!=1.15.0' 'bitsandbytes>=0.42.0' 'timm==0.9.10'; \ + pip install accelerate hf_transfer 'modelscope!=1.15.0' 'bitsandbytes>=0.42.0' 'timm==0.9.10' boto3 runai-model-streamer runai-model-streamer[s3]; \ else \ - pip install accelerate hf_transfer 'modelscope!=1.15.0' 'bitsandbytes>=0.45.0' 'timm==0.9.10'; \ + pip install accelerate hf_transfer 'modelscope!=1.15.0' 'bitsandbytes>=0.45.0' 'timm==0.9.10' boto3 runai-model-streamer runai-model-streamer[s3]; \ fi ENV VLLM_USAGE_SOURCE production-docker-image diff --git a/docs/source/index.rst b/docs/source/index.rst index fd741ea5e9766..d812885aafea9 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -88,6 +88,7 @@ Documentation serving/metrics serving/integrations serving/tensorizer + serving/runai_model_streamer .. toctree:: :maxdepth: 1 diff --git a/docs/source/serving/runai_model_streamer.rst b/docs/source/serving/runai_model_streamer.rst new file mode 100644 index 0000000000000..459eb8677fb95 --- /dev/null +++ b/docs/source/serving/runai_model_streamer.rst @@ -0,0 +1,53 @@ +.. _runai_model_streamer: + +Loading Models with Run:ai Model Streamer +========================================= +Run:ai Model Streamer is a library to read tensors in concurrency, while streaming it to GPU memory. +Further reading can be found in `Run:ai Model Streamer Documentation `_. + +vLLM supports loading weights in Safetensors format using the Run:ai Model Streamer. +You first need to install vLLM RunAI optional dependency: + +.. code-block:: console + + $ pip3 install vllm[runai] + +To run it as an OpenAI-compatible server, add the `--load-format runai_streamer` flag: + +.. code-block:: console + + $ vllm serve /home/meta-llama/Llama-3.2-3B-Instruct --load-format runai_streamer + +To run model from AWS S3 object store run: + +.. code-block:: console + + $ vllm serve s3://core-llm/Llama-3-8b --load-format runai_streamer + + +To run model from a S3 compatible object store run: + +.. code-block:: console + + $ RUNAI_STREAMER_S3_USE_VIRTUAL_ADDRESSING=0 AWS_EC2_METADATA_DISABLED=true AWS_ENDPOINT_URL=https://storage.googleapis.com vllm serve s3://core-llm/Llama-3-8b --load-format runai_streamer + +Tunable parameters +------------------ +You can tune parameters using `--model-loader-extra-config`: + +You can tune `concurrency` that controls the level of concurrency and number of OS threads reading tensors from the file to the CPU buffer. +For reading from S3, it will be the number of client instances the host is opening to the S3 server. + + .. code-block:: console + + $ vllm serve /home/meta-llama/Llama-3.2-3B-Instruct --load-format runai_streamer --model-loader-extra-config '{"concurrency":16}' + +You can controls the size of the CPU Memory buffer to which tensors are read from the file, and limit this size. +You can read further about CPU buffer memory limiting `here `_. + + .. code-block:: console + + $ vllm serve /home/meta-llama/Llama-3.2-3B-Instruct --load-format runai_streamer --model-loader-extra-config '{"memory_limit":5368709120}' + +.. note:: + For further instructions about tunable parameters and additional parameters configurable through environment variables, read the `Environment Variables Documentation `_. diff --git a/setup.py b/setup.py index a860093fe5f35..73407b64edf22 100644 --- a/setup.py +++ b/setup.py @@ -630,6 +630,7 @@ def _read_requirements(filename: str) -> List[str]: ext_modules=ext_modules, extras_require={ "tensorizer": ["tensorizer>=2.9.0"], + "runai": ["runai-model-streamer", "runai-model-streamer-s3", "boto3"], "audio": ["librosa", "soundfile"], # Required for audio processing "video": ["decord"] # Required for video processing }, diff --git a/tests/runai_model_streamer/__init__.py b/tests/runai_model_streamer/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/runai_model_streamer/test_runai_model_streamer_loader.py b/tests/runai_model_streamer/test_runai_model_streamer_loader.py new file mode 100644 index 0000000000000..c5722fbae5c8a --- /dev/null +++ b/tests/runai_model_streamer/test_runai_model_streamer_loader.py @@ -0,0 +1,31 @@ +from vllm import SamplingParams +from vllm.config import LoadConfig, LoadFormat +from vllm.model_executor.model_loader.loader import (RunaiModelStreamerLoader, + get_model_loader) + +test_model = "openai-community/gpt2" + +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] +# Create a sampling params object. +sampling_params = SamplingParams(temperature=0.8, top_p=0.95, seed=0) + + +def get_runai_model_loader(): + load_config = LoadConfig(load_format=LoadFormat.RUNAI_STREAMER) + return get_model_loader(load_config) + + +def test_get_model_loader_with_runai_flag(): + model_loader = get_runai_model_loader() + assert isinstance(model_loader, RunaiModelStreamerLoader) + + +def test_runai_model_loader_download_files(vllm_runner): + with vllm_runner(test_model, load_format=LoadFormat.RUNAI_STREAMER) as llm: + deserialized_outputs = llm.generate(prompts, sampling_params) + assert deserialized_outputs diff --git a/tests/runai_model_streamer/test_weight_utils.py b/tests/runai_model_streamer/test_weight_utils.py new file mode 100644 index 0000000000000..5c89bd78ad81d --- /dev/null +++ b/tests/runai_model_streamer/test_weight_utils.py @@ -0,0 +1,39 @@ +import glob +import tempfile + +import huggingface_hub.constants +import torch + +from vllm.model_executor.model_loader.weight_utils import ( + download_weights_from_hf, runai_safetensors_weights_iterator, + safetensors_weights_iterator) + + +def test_runai_model_loader(): + with tempfile.TemporaryDirectory() as tmpdir: + huggingface_hub.constants.HF_HUB_OFFLINE = False + download_weights_from_hf("openai-community/gpt2", + allow_patterns=["*.safetensors"], + cache_dir=tmpdir) + safetensors = glob.glob(f"{tmpdir}/**/*.safetensors", recursive=True) + assert len(safetensors) > 0 + + runai_model_streamer_tensors = {} + hf_safetensors_tensors = {} + + for name, tensor in runai_safetensors_weights_iterator(safetensors): + runai_model_streamer_tensors[name] = tensor + + for name, tensor in safetensors_weights_iterator(safetensors): + hf_safetensors_tensors[name] = tensor + + assert len(runai_model_streamer_tensors) == len(hf_safetensors_tensors) + + for name, runai_tensor in runai_model_streamer_tensors.items(): + assert runai_tensor.dtype == hf_safetensors_tensors[name].dtype + assert runai_tensor.shape == hf_safetensors_tensors[name].shape + assert torch.all(runai_tensor.eq(hf_safetensors_tensors[name])) + + +if __name__ == "__main__": + test_runai_model_loader() diff --git a/vllm/config.py b/vllm/config.py index 6badae24d9d7d..643698f8bbec3 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -29,6 +29,7 @@ get_hf_text_config, get_pooling_config, get_sentence_transformer_tokenizer_config, is_encoder_decoder, try_get_generation_config, uses_mrope) +from vllm.transformers_utils.utils import is_s3 from vllm.utils import (GiB_bytes, LayerBlockType, cuda_device_count_stateless, get_cpu_memory, print_warning_once, random_uuid, resolve_obj_by_qualname) @@ -256,6 +257,8 @@ def __init__(self, f"'Please instead use `--hf-overrides '{hf_override!r}'`") warnings.warn(DeprecationWarning(msg), stacklevel=2) + self.maybe_pull_model_tokenizer_for_s3(model, tokenizer) + # The tokenizer version is consistent with the model version by default. if tokenizer_revision is None: self.tokenizer_revision = revision @@ -357,6 +360,39 @@ def __init__(self, self._verify_cuda_graph() self._verify_bnb_config() + def maybe_pull_model_tokenizer_for_s3(self, model: str, + tokenizer: str) -> None: + """ + Pull the model config or tokenizer to a temporary + directory in case of S3. + + Args: + model: The model name or path. + tokenizer: The tokenizer name or path. + + """ + if is_s3(model) or is_s3(tokenizer): + try: + from vllm.transformers_utils.s3_utils import S3Model + except ImportError as err: + raise ImportError( + "Please install Run:ai optional dependency " + "to use the S3 capabilities. " + "You can install it with: pip install vllm[runai]" + ) from err + + if is_s3(model): + self.s3_model = S3Model() + self.s3_model.pull_files(model, allow_pattern=["*config.json"]) + self.model_weights = self.model + self.model = self.s3_model.dir + + if is_s3(tokenizer): + self.s3_tokenizer = S3Model() + self.s3_tokenizer.pull_files( + model, ignore_pattern=["*.pt", "*.safetensors", "*.bin"]) + self.tokenizer = self.s3_tokenizer.dir + def _init_multimodal_config( self, limit_mm_per_prompt: Optional[Mapping[str, int]] ) -> Optional["MultiModalConfig"]: @@ -1099,6 +1135,7 @@ class LoadFormat(str, enum.Enum): GGUF = "gguf" BITSANDBYTES = "bitsandbytes" MISTRAL = "mistral" + RUNAI_STREAMER = "runai_streamer" @dataclass diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 912a8b2f54adb..7aa45b7958e26 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -316,6 +316,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: '* "tensorizer" will load the weights using tensorizer from ' 'CoreWeave. See the Tensorize vLLM Model script in the Examples ' 'section for more information.\n' + '* "runai_streamer" will load the Safetensors weights using Run:ai' + 'Model Streamer \n' '* "bitsandbytes" will load the weights using bitsandbytes ' 'quantization.\n') parser.add_argument( diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index fdc4c6305bd5e..24e554e6060ab 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -45,9 +45,10 @@ filter_duplicate_safetensors_files, filter_files_not_needed_for_inference, get_gguf_extra_tensor_names, gguf_quant_weights_iterator, initialize_dummy_weights, np_cache_weights_iterator, pt_weights_iterator, - safetensors_weights_iterator) + runai_safetensors_weights_iterator, safetensors_weights_iterator) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform +from vllm.transformers_utils.utils import is_s3 from vllm.utils import is_pin_memory_available @@ -1234,6 +1235,118 @@ def load_model(self, vllm_config: VllmConfig) -> nn.Module: return model +class RunaiModelStreamerLoader(BaseModelLoader): + """ + Model loader that can load safetensors + files from local FS or S3 bucket. + """ + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + if load_config.model_loader_extra_config: + extra_config = load_config.model_loader_extra_config + + if ("concurrency" in extra_config + and isinstance(extra_config.get("concurrency"), int)): + os.environ["RUNAI_STREAMER_CONCURRENCY"] = str( + extra_config.get("concurrency")) + + if ("memory_limit" in extra_config + and isinstance(extra_config.get("memory_limit"), int)): + os.environ["RUNAI_STREAMER_MEMORY_LIMIT"] = str( + extra_config.get("memory_limit")) + + runai_streamer_s3_endpoint = os.getenv( + 'RUNAI_STREAMER_S3_ENDPOINT') + aws_endpoint_url = os.getenv('AWS_ENDPOINT_URL') + if (runai_streamer_s3_endpoint is None + and aws_endpoint_url is not None): + os.environ["RUNAI_STREAMER_S3_ENDPOINT"] = aws_endpoint_url + + def _prepare_weights(self, model_name_or_path: str, + revision: Optional[str]) -> List[str]: + """Prepare weights for the model. + + If the model is not local, it will be downloaded.""" + is_s3_path = is_s3(model_name_or_path) + if is_s3_path: + try: + from vllm.transformers_utils.s3_utils import glob as s3_glob + except ImportError as err: + raise ImportError( + "Please install Run:ai optional dependency " + "to use the S3 capabilities. " + "You can install it with: pip install vllm[runai]" + ) from err + + is_local = os.path.isdir(model_name_or_path) + safetensors_pattern = "*.safetensors" + index_file = SAFE_WEIGHTS_INDEX_NAME + + hf_folder = (model_name_or_path if + (is_local or is_s3_path) else download_weights_from_hf( + model_name_or_path, + self.load_config.download_dir, + [safetensors_pattern], + revision, + ignore_patterns=self.load_config.ignore_patterns, + )) + + if is_s3_path: + hf_weights_files = s3_glob(path=hf_folder, + allow_pattern=[safetensors_pattern]) + else: + hf_weights_files = glob.glob( + os.path.join(hf_folder, safetensors_pattern)) + + if not is_local and not is_s3_path: + download_safetensors_index_file_from_hf( + model_name_or_path, index_file, self.load_config.download_dir, + revision) + + if not hf_weights_files: + raise RuntimeError( + f"Cannot find any safetensors model weights with " + f"`{model_name_or_path}`") + + return hf_weights_files + + def _get_weights_iterator( + self, model_or_path: str, + revision: str) -> Generator[Tuple[str, torch.Tensor], None, None]: + """Get an iterator for the model weights based on the load format.""" + hf_weights_files = self._prepare_weights(model_or_path, revision) + return runai_safetensors_weights_iterator(hf_weights_files) + + def download_model(self, model_config: ModelConfig) -> None: + """Download model if necessary""" + self._prepare_weights(model_config.model, model_config.revision) + + def load_model(self, vllm_config: VllmConfig) -> nn.Module: + """Perform streaming of the model to destination""" + device_config = vllm_config.device_config + model_config = vllm_config.model_config + + target_device = torch.device(device_config.device) + with set_default_torch_dtype(model_config.dtype): + with target_device: + model = _initialize_model(vllm_config=vllm_config) + + model_weights = model_config.model + if hasattr(model_config, "model_weights"): + model_weights = model_config.model_weights + model.load_weights( + self._get_weights_iterator(model_weights, + model_config.revision)) + + for _, module in model.named_modules(): + quant_method = getattr(module, "quant_method", None) + if quant_method is not None: + with device_loading_context(module, target_device): + quant_method.process_weights_after_loading(module) + return model.eval() + + def get_model_loader(load_config: LoadConfig) -> BaseModelLoader: """Get a model loader based on the load format.""" @@ -1255,4 +1368,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader: if load_config.load_format == LoadFormat.GGUF: return GGUFModelLoader(load_config) + if load_config.load_format == LoadFormat.RUNAI_STREAMER: + return RunaiModelStreamerLoader(load_config) + return DefaultModelLoader(load_config) diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index 9488d54edf365..f2a9e7e2687cb 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -410,6 +410,30 @@ def safetensors_weights_iterator( yield name, param +def runai_safetensors_weights_iterator( + hf_weights_files: List[str] +) -> Generator[Tuple[str, torch.Tensor], None, None]: + """Iterate over the weights in the model safetensor files.""" + try: + from runai_model_streamer import SafetensorsStreamer + except ImportError as err: + raise ImportError( + "Please install Run:ai optional dependency." + "You can install it with: pip install vllm[runai]") from err + + enable_tqdm = not torch.distributed.is_initialized( + ) or torch.distributed.get_rank() == 0 + with SafetensorsStreamer() as streamer: + for st_file in tqdm( + hf_weights_files, + desc="Loading safetensors using Runai Model Streamer", + disable=not enable_tqdm, + bar_format=_BAR_FORMAT, + ): + streamer.stream_file(st_file) + yield from streamer.get_tensors() + + def pt_weights_iterator( hf_weights_files: List[str] ) -> Generator[Tuple[str, torch.Tensor], None, None]: diff --git a/vllm/transformers_utils/s3_utils.py b/vllm/transformers_utils/s3_utils.py new file mode 100644 index 0000000000000..6f63dab74d696 --- /dev/null +++ b/vllm/transformers_utils/s3_utils.py @@ -0,0 +1,146 @@ +import fnmatch +import os +import shutil +import signal +import tempfile +from pathlib import Path +from typing import Optional + +import boto3 + + +def _filter_allow(paths: list[str], patterns: list[str]) -> list[str]: + return [ + path for path in paths if any( + fnmatch.fnmatch(path, pattern) for pattern in patterns) + ] + + +def _filter_ignore(paths: list[str], patterns: list[str]) -> list[str]: + return [ + path for path in paths + if not any(fnmatch.fnmatch(path, pattern) for pattern in patterns) + ] + + +def glob(s3=None, + path: str = "", + allow_pattern: Optional[list[str]] = None) -> list[str]: + """ + List full file names from S3 path and filter by allow pattern. + + Args: + s3: S3 client to use. + path: The S3 path to list from. + allow_pattern: A list of patterns of which files to pull. + + Returns: + list[str]: List of full S3 paths allowed by the pattern + """ + if s3 is None: + s3 = boto3.client("s3") + bucket_name, _, paths = list_files(s3, + path=path, + allow_pattern=allow_pattern) + return [f"s3://{bucket_name}/{path}" for path in paths] + + +def list_files( + s3, + path: str, + allow_pattern: Optional[list[str]] = None, + ignore_pattern: Optional[list[str]] = None +) -> tuple[str, str, list[str]]: + """ + List files from S3 path and filter by pattern. + + Args: + s3: S3 client to use. + path: The S3 path to list from. + allow_pattern: A list of patterns of which files to pull. + ignore_pattern: A list of patterns of which files not to pull. + + Returns: + tuple[str, str, list[str]]: A tuple where: + - The first element is the bucket name + - The second element is string represent the bucket + and the prefix as a dir like string + - The third element is a list of files allowed or + disallowed by pattern + """ + parts = path.removeprefix('s3://').split('/') + prefix = '/'.join(parts[1:]) + bucket_name = parts[0] + + objects = s3.list_objects_v2(Bucket=bucket_name, Prefix=prefix) + paths = [obj['Key'] for obj in objects.get('Contents', [])] + + paths = _filter_ignore(paths, ["*/"]) + if allow_pattern is not None: + paths = _filter_allow(paths, allow_pattern) + + if ignore_pattern is not None: + paths = _filter_ignore(paths, ignore_pattern) + + return bucket_name, prefix, paths + + +class S3Model: + """ + A class representing a S3 model mirrored into a temporary directory. + + Attributes: + s3: S3 client. + dir: The temporary created directory. + + Methods: + pull_files(): Pull model from S3 to the temporary directory. + """ + + def __init__(self) -> None: + self.s3 = boto3.client('s3') + for sig in (signal.SIGINT, signal.SIGTERM): + existing_handler = signal.getsignal(sig) + signal.signal(sig, self._close_by_signal(existing_handler)) + self.dir = tempfile.mkdtemp() + + def __del__(self): + self._close() + + def _close(self) -> None: + if os.path.exists(self.dir): + shutil.rmtree(self.dir) + + def _close_by_signal(self, existing_handler=None): + + def new_handler(signum, frame): + self._close() + if existing_handler: + existing_handler(signum, frame) + + return new_handler + + def pull_files(self, + s3_model_path: str = "", + allow_pattern: Optional[list[str]] = None, + ignore_pattern: Optional[list[str]] = None) -> None: + """ + Pull files from S3 storage into the temporary directory. + + Args: + s3_model_path: The S3 path of the model. + allow_pattern: A list of patterns of which files to pull. + ignore_pattern: A list of patterns of which files not to pull. + + """ + bucket_name, base_dir, files = list_files(self.s3, s3_model_path, + allow_pattern, + ignore_pattern) + if len(files) == 0: + return + + for file in files: + destination_file = self.dir + file.removeprefix(base_dir) + local_dir = Path(destination_file).parent + os.makedirs(local_dir, exist_ok=True) + self.s3.download_file(bucket_name, file, destination_file) diff --git a/vllm/transformers_utils/utils.py b/vllm/transformers_utils/utils.py index 7a9041b04fbb9..10a09fb4f566c 100644 --- a/vllm/transformers_utils/utils.py +++ b/vllm/transformers_utils/utils.py @@ -3,6 +3,10 @@ from typing import Union +def is_s3(model_or_path: str) -> bool: + return model_or_path.lower().startswith('s3://') + + def check_gguf_file(model: Union[str, PathLike]) -> bool: """Check if the file is a GGUF model.""" model = Path(model)