Skip to content

Commit

Permalink
Resolve merge conflicts
Browse files Browse the repository at this point in the history
Signed-off-by: Rafael Vasquez <[email protected]>
  • Loading branch information
rafvasq committed Dec 20, 2024
2 parents b6af3da + 995f562 commit aad6927
Show file tree
Hide file tree
Showing 13 changed files with 457 additions and 3 deletions.
4 changes: 2 additions & 2 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -240,9 +240,9 @@ FROM vllm-base AS vllm-openai
# install additional dependencies for openai api server
RUN --mount=type=cache,target=/root/.cache/pip \
if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \
pip install accelerate hf_transfer 'modelscope!=1.15.0' 'bitsandbytes>=0.42.0' 'timm==0.9.10'; \
pip install accelerate hf_transfer 'modelscope!=1.15.0' 'bitsandbytes>=0.42.0' 'timm==0.9.10' boto3 runai-model-streamer runai-model-streamer[s3]; \
else \
pip install accelerate hf_transfer 'modelscope!=1.15.0' 'bitsandbytes>=0.45.0' 'timm==0.9.10'; \
pip install accelerate hf_transfer 'modelscope!=1.15.0' 'bitsandbytes>=0.45.0' 'timm==0.9.10' boto3 runai-model-streamer runai-model-streamer[s3]; \
fi

ENV VLLM_USAGE_SOURCE production-docker-image
Expand Down
1 change: 1 addition & 0 deletions docs/source/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ serving/distributed_serving
serving/metrics
serving/integrations
serving/tensorizer
serving/runai_model_streamer
```

```{toctree}
Expand Down
53 changes: 53 additions & 0 deletions docs/source/serving/runai_model_streamer.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
(runai-model-streamer)=

# Loading Models with Run:ai Model Streamer

Run:ai Model Streamer is a library to read tensors in concurrency, while streaming it to GPU memory.
Further reading can be found in [Run:ai Model Streamer Documentation](https://github.com/run-ai/runai-model-streamer/blob/master/docs/README.md).

vLLM supports loading weights in Safetensors format using the Run:ai Model Streamer.
You first need to install vLLM RunAI optional dependency:

```console
$ pip3 install vllm[runai]
```

To run it as an OpenAI-compatible server, add the `--load-format runai_streamer` flag:

```console
$ vllm serve /home/meta-llama/Llama-3.2-3B-Instruct --load-format runai_streamer
```

To run model from AWS S3 object store run:

```console
$ vllm serve s3://core-llm/Llama-3-8b --load-format runai_streamer
```

To run model from a S3 compatible object store run:

```console
$ RUNAI_STREAMER_S3_USE_VIRTUAL_ADDRESSING=0 AWS_EC2_METADATA_DISABLED=true AWS_ENDPOINT_URL=https://storage.googleapis.com vllm serve s3://core-llm/Llama-3-8b --load-format runai_streamer
```

## Tunable parameters

You can tune parameters using `--model-loader-extra-config`:

You can tune `concurrency` that controls the level of concurrency and number of OS threads reading tensors from the file to the CPU buffer.
For reading from S3, it will be the number of client instances the host is opening to the S3 server.

> ```console
> $ vllm serve /home/meta-llama/Llama-3.2-3B-Instruct --load-format runai_streamer --model-loader-extra-config '{"concurrency":16}'
> ```
You can controls the size of the CPU Memory buffer to which tensors are read from the file, and limit this size.
You can read further about CPU buffer memory limiting [here](https://github.com/run-ai/runai-model-streamer/blob/master/docs/src/env-vars.md#runai_streamer_memory_limit).
> ```console
> $ vllm serve /home/meta-llama/Llama-3.2-3B-Instruct --load-format runai_streamer --model-loader-extra-config '{"memory_limit":5368709120}'
> ```
```{note}
For further instructions about tunable parameters and additional parameters configurable through environment variables, read the [Environment Variables Documentation](https://github.com/run-ai/runai-model-streamer/blob/master/docs/src/env-vars.md).
```
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -630,6 +630,7 @@ def _read_requirements(filename: str) -> List[str]:
ext_modules=ext_modules,
extras_require={
"tensorizer": ["tensorizer>=2.9.0"],
"runai": ["runai-model-streamer", "runai-model-streamer-s3", "boto3"],
"audio": ["librosa", "soundfile"], # Required for audio processing
"video": ["decord"] # Required for video processing
},
Expand Down
Empty file.
31 changes: 31 additions & 0 deletions tests/runai_model_streamer/test_runai_model_streamer_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from vllm import SamplingParams
from vllm.config import LoadConfig, LoadFormat
from vllm.model_executor.model_loader.loader import (RunaiModelStreamerLoader,
get_model_loader)

test_model = "openai-community/gpt2"

prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, seed=0)


def get_runai_model_loader():
load_config = LoadConfig(load_format=LoadFormat.RUNAI_STREAMER)
return get_model_loader(load_config)


def test_get_model_loader_with_runai_flag():
model_loader = get_runai_model_loader()
assert isinstance(model_loader, RunaiModelStreamerLoader)


def test_runai_model_loader_download_files(vllm_runner):
with vllm_runner(test_model, load_format=LoadFormat.RUNAI_STREAMER) as llm:
deserialized_outputs = llm.generate(prompts, sampling_params)
assert deserialized_outputs
39 changes: 39 additions & 0 deletions tests/runai_model_streamer/test_weight_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import glob
import tempfile

import huggingface_hub.constants
import torch

from vllm.model_executor.model_loader.weight_utils import (
download_weights_from_hf, runai_safetensors_weights_iterator,
safetensors_weights_iterator)


def test_runai_model_loader():
with tempfile.TemporaryDirectory() as tmpdir:
huggingface_hub.constants.HF_HUB_OFFLINE = False
download_weights_from_hf("openai-community/gpt2",
allow_patterns=["*.safetensors"],
cache_dir=tmpdir)
safetensors = glob.glob(f"{tmpdir}/**/*.safetensors", recursive=True)
assert len(safetensors) > 0

runai_model_streamer_tensors = {}
hf_safetensors_tensors = {}

for name, tensor in runai_safetensors_weights_iterator(safetensors):
runai_model_streamer_tensors[name] = tensor

for name, tensor in safetensors_weights_iterator(safetensors):
hf_safetensors_tensors[name] = tensor

assert len(runai_model_streamer_tensors) == len(hf_safetensors_tensors)

for name, runai_tensor in runai_model_streamer_tensors.items():
assert runai_tensor.dtype == hf_safetensors_tensors[name].dtype
assert runai_tensor.shape == hf_safetensors_tensors[name].shape
assert torch.all(runai_tensor.eq(hf_safetensors_tensors[name]))


if __name__ == "__main__":
test_runai_model_loader()
37 changes: 37 additions & 0 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
get_hf_text_config, get_pooling_config,
get_sentence_transformer_tokenizer_config, is_encoder_decoder,
try_get_generation_config, uses_mrope)
from vllm.transformers_utils.utils import is_s3
from vllm.utils import (GiB_bytes, LayerBlockType, cuda_device_count_stateless,
get_cpu_memory, print_warning_once, random_uuid,
resolve_obj_by_qualname)
Expand Down Expand Up @@ -256,6 +257,8 @@ def __init__(self,
f"'Please instead use `--hf-overrides '{hf_override!r}'`")
warnings.warn(DeprecationWarning(msg), stacklevel=2)

self.maybe_pull_model_tokenizer_for_s3(model, tokenizer)

# The tokenizer version is consistent with the model version by default.
if tokenizer_revision is None:
self.tokenizer_revision = revision
Expand Down Expand Up @@ -357,6 +360,39 @@ def __init__(self,
self._verify_cuda_graph()
self._verify_bnb_config()

def maybe_pull_model_tokenizer_for_s3(self, model: str,
tokenizer: str) -> None:
"""
Pull the model config or tokenizer to a temporary
directory in case of S3.
Args:
model: The model name or path.
tokenizer: The tokenizer name or path.
"""
if is_s3(model) or is_s3(tokenizer):
try:
from vllm.transformers_utils.s3_utils import S3Model
except ImportError as err:
raise ImportError(
"Please install Run:ai optional dependency "
"to use the S3 capabilities. "
"You can install it with: pip install vllm[runai]"
) from err

if is_s3(model):
self.s3_model = S3Model()
self.s3_model.pull_files(model, allow_pattern=["*config.json"])
self.model_weights = self.model
self.model = self.s3_model.dir

if is_s3(tokenizer):
self.s3_tokenizer = S3Model()
self.s3_tokenizer.pull_files(
model, ignore_pattern=["*.pt", "*.safetensors", "*.bin"])
self.tokenizer = self.s3_tokenizer.dir

def _init_multimodal_config(
self, limit_mm_per_prompt: Optional[Mapping[str, int]]
) -> Optional["MultiModalConfig"]:
Expand Down Expand Up @@ -1099,6 +1135,7 @@ class LoadFormat(str, enum.Enum):
GGUF = "gguf"
BITSANDBYTES = "bitsandbytes"
MISTRAL = "mistral"
RUNAI_STREAMER = "runai_streamer"


@dataclass
Expand Down
2 changes: 2 additions & 0 deletions vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
'* "tensorizer" will load the weights using tensorizer from '
'CoreWeave. See the Tensorize vLLM Model script in the Examples '
'section for more information.\n'
'* "runai_streamer" will load the Safetensors weights using Run:ai'
'Model Streamer \n'
'* "bitsandbytes" will load the weights using bitsandbytes '
'quantization.\n')
parser.add_argument(
Expand Down
118 changes: 117 additions & 1 deletion vllm/model_executor/model_loader/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,10 @@
filter_duplicate_safetensors_files, filter_files_not_needed_for_inference,
get_gguf_extra_tensor_names, gguf_quant_weights_iterator,
initialize_dummy_weights, np_cache_weights_iterator, pt_weights_iterator,
safetensors_weights_iterator)
runai_safetensors_weights_iterator, safetensors_weights_iterator)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.transformers_utils.utils import is_s3
from vllm.utils import is_pin_memory_available


Expand Down Expand Up @@ -1234,6 +1235,118 @@ def load_model(self, vllm_config: VllmConfig) -> nn.Module:
return model


class RunaiModelStreamerLoader(BaseModelLoader):
"""
Model loader that can load safetensors
files from local FS or S3 bucket.
"""

def __init__(self, load_config: LoadConfig):
super().__init__(load_config)
if load_config.model_loader_extra_config:
extra_config = load_config.model_loader_extra_config

if ("concurrency" in extra_config
and isinstance(extra_config.get("concurrency"), int)):
os.environ["RUNAI_STREAMER_CONCURRENCY"] = str(
extra_config.get("concurrency"))

if ("memory_limit" in extra_config
and isinstance(extra_config.get("memory_limit"), int)):
os.environ["RUNAI_STREAMER_MEMORY_LIMIT"] = str(
extra_config.get("memory_limit"))

runai_streamer_s3_endpoint = os.getenv(
'RUNAI_STREAMER_S3_ENDPOINT')
aws_endpoint_url = os.getenv('AWS_ENDPOINT_URL')
if (runai_streamer_s3_endpoint is None
and aws_endpoint_url is not None):
os.environ["RUNAI_STREAMER_S3_ENDPOINT"] = aws_endpoint_url

def _prepare_weights(self, model_name_or_path: str,
revision: Optional[str]) -> List[str]:
"""Prepare weights for the model.
If the model is not local, it will be downloaded."""
is_s3_path = is_s3(model_name_or_path)
if is_s3_path:
try:
from vllm.transformers_utils.s3_utils import glob as s3_glob
except ImportError as err:
raise ImportError(
"Please install Run:ai optional dependency "
"to use the S3 capabilities. "
"You can install it with: pip install vllm[runai]"
) from err

is_local = os.path.isdir(model_name_or_path)
safetensors_pattern = "*.safetensors"
index_file = SAFE_WEIGHTS_INDEX_NAME

hf_folder = (model_name_or_path if
(is_local or is_s3_path) else download_weights_from_hf(
model_name_or_path,
self.load_config.download_dir,
[safetensors_pattern],
revision,
ignore_patterns=self.load_config.ignore_patterns,
))

if is_s3_path:
hf_weights_files = s3_glob(path=hf_folder,
allow_pattern=[safetensors_pattern])
else:
hf_weights_files = glob.glob(
os.path.join(hf_folder, safetensors_pattern))

if not is_local and not is_s3_path:
download_safetensors_index_file_from_hf(
model_name_or_path, index_file, self.load_config.download_dir,
revision)

if not hf_weights_files:
raise RuntimeError(
f"Cannot find any safetensors model weights with "
f"`{model_name_or_path}`")

return hf_weights_files

def _get_weights_iterator(
self, model_or_path: str,
revision: str) -> Generator[Tuple[str, torch.Tensor], None, None]:
"""Get an iterator for the model weights based on the load format."""
hf_weights_files = self._prepare_weights(model_or_path, revision)
return runai_safetensors_weights_iterator(hf_weights_files)

def download_model(self, model_config: ModelConfig) -> None:
"""Download model if necessary"""
self._prepare_weights(model_config.model, model_config.revision)

def load_model(self, vllm_config: VllmConfig) -> nn.Module:
"""Perform streaming of the model to destination"""
device_config = vllm_config.device_config
model_config = vllm_config.model_config

target_device = torch.device(device_config.device)
with set_default_torch_dtype(model_config.dtype):
with target_device:
model = _initialize_model(vllm_config=vllm_config)

model_weights = model_config.model
if hasattr(model_config, "model_weights"):
model_weights = model_config.model_weights
model.load_weights(
self._get_weights_iterator(model_weights,
model_config.revision))

for _, module in model.named_modules():
quant_method = getattr(module, "quant_method", None)
if quant_method is not None:
with device_loading_context(module, target_device):
quant_method.process_weights_after_loading(module)
return model.eval()


def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
"""Get a model loader based on the load format."""

Expand All @@ -1255,4 +1368,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
if load_config.load_format == LoadFormat.GGUF:
return GGUFModelLoader(load_config)

if load_config.load_format == LoadFormat.RUNAI_STREAMER:
return RunaiModelStreamerLoader(load_config)

return DefaultModelLoader(load_config)
24 changes: 24 additions & 0 deletions vllm/model_executor/model_loader/weight_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,30 @@ def safetensors_weights_iterator(
yield name, param


def runai_safetensors_weights_iterator(
hf_weights_files: List[str]
) -> Generator[Tuple[str, torch.Tensor], None, None]:
"""Iterate over the weights in the model safetensor files."""
try:
from runai_model_streamer import SafetensorsStreamer
except ImportError as err:
raise ImportError(
"Please install Run:ai optional dependency."
"You can install it with: pip install vllm[runai]") from err

enable_tqdm = not torch.distributed.is_initialized(
) or torch.distributed.get_rank() == 0
with SafetensorsStreamer() as streamer:
for st_file in tqdm(
hf_weights_files,
desc="Loading safetensors using Runai Model Streamer",
disable=not enable_tqdm,
bar_format=_BAR_FORMAT,
):
streamer.stream_file(st_file)
yield from streamer.get_tensors()


def pt_weights_iterator(
hf_weights_files: List[str]
) -> Generator[Tuple[str, torch.Tensor], None, None]:
Expand Down
Loading

0 comments on commit aad6927

Please sign in to comment.