Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Core] add stream loader into model loader #11

Open
wants to merge 18 commits into
base: main
Choose a base branch
from
Open
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,7 @@ def _read_requirements(filename: str) -> List[str]:
install_requires=get_requirements(),
ext_modules=ext_modules,
extras_require={
"stream": ["boto3>=1.35.5"],
"tensorizer": ["tensorizer>=2.9.0"],
"video": ["opencv-python"], # Required for video processing
"audio": ["librosa", "soundfile"] # Required for audio processing
Expand Down
1 change: 1 addition & 0 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -722,6 +722,7 @@ class LoadFormat(str, enum.Enum):
GGUF = "gguf"
BITSANDBYTES = "bitsandbytes"
MISTRAL = "mistral"
STREAM = "stream"


@dataclass
Expand Down
29 changes: 28 additions & 1 deletion vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import dataclasses
import json
from dataclasses import dataclass
from pathlib import Path
from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple,
Type, Union)

Expand Down Expand Up @@ -258,7 +259,9 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
'CoreWeave. See the Tensorize vLLM Model script in the Examples '
'section for more information.\n'
'* "bitsandbytes" will load the weights using bitsandbytes '
'quantization.\n')
'quantization.\n'
'* "stream" will load the weights using stream from remote storage,'
'like S3 or TOS.\n')
parser.add_argument(
'--config-format',
default=EngineArgs.config_format,
Expand Down Expand Up @@ -796,6 +799,30 @@ def from_cli_args(cls, args: argparse.Namespace):
return engine_args

def create_model_config(self) -> ModelConfig:
if self.load_format == "stream":
from vllm.model_executor.model_loader.stream_loader import (
StreamConfig)

# download config json to `download_dir`
# and replace `model` with `download_dir`
model_loader_extra_config = self.model_loader_extra_config or {}
if isinstance(model_loader_extra_config, str):
model_loader_extra_config = json.loads(
model_loader_extra_config)

stream_config = StreamConfig(**model_loader_extra_config)
stream_model = stream_config.construct_stream_model()

config_dir = self.download_dir or str(
Path("/tmp/stream_load/").joinpath(self.model))
config_path = stream_model.download_config(config_dir)

reset_tokenizer = self.model == self.tokenizer
self.served_model_name = self.served_model_name or self.model
self.model = str(config_path)
if reset_tokenizer:
self.tokenizer = self.model

return ModelConfig(
model=self.model,
tokenizer=self.tokenizer,
Expand Down
44 changes: 44 additions & 0 deletions vllm/model_executor/model_loader/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.model_loader.stream_loader import StreamConfig
from vllm.model_executor.model_loader.tensorizer import (
TensorizerConfig, is_vllm_tensorized, load_with_tensorizer,
serialize_vllm_model, tensorizer_weights_iterator)
Expand Down Expand Up @@ -1164,6 +1165,46 @@ def load_model(self, *, model_config: ModelConfig,
return model


class StreamModelLoader(BaseModelLoader):
    """Model loader using AiBrix's stream loader library."""

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        extra_config = load_config.model_loader_extra_config
        # Accept either a ready-made StreamConfig or a plain mapping of
        # its constructor arguments.
        if not isinstance(extra_config, StreamConfig):
            extra_config = StreamConfig(**extra_config)
        self.stream_loader_config = extra_config
        self.stream_model = self.stream_loader_config.construct_stream_model()

    def _verify_config(self, model_config: ModelConfig,
                       parallel_config: ParallelConfig):
        # Validation lives on the stream config; just delegate.
        self.stream_loader_config.verify_with_model_config(model_config)
        self.stream_loader_config.verify_with_parallel_config(parallel_config)

    def load_model(self, *, model_config: ModelConfig,
                   device_config: DeviceConfig,
                   lora_config: Optional[LoRAConfig],
                   parallel_config: ParallelConfig,
                   scheduler_config: SchedulerConfig,
                   cache_config: CacheConfig) -> nn.Module:
        """Initialize the model on the target device, then stream its
        weights in via the stream model's "cpu" weight iterator."""
        self._verify_config(model_config, parallel_config)

        with set_default_torch_dtype(model_config.dtype), \
                torch.device(device_config.device):
            model = _initialize_model(model_config, self.load_config,
                                      lora_config, cache_config,
                                      scheduler_config)

        # Feed weights from the streaming source into the instantiated model.
        model.load_weights(self.stream_model.get_weights_iterator("cpu"))
        return model.eval()

    def download_model(self, model_config: ModelConfig) -> None:
        pass  # Nothing to download; weights are streamed at load time.


def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
"""Get a model loader based on the load format."""

Expand All @@ -1185,4 +1226,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
if load_config.load_format == LoadFormat.GGUF:
return GGUFModelLoader(load_config)

if load_config.load_format == LoadFormat.STREAM:
return StreamModelLoader(load_config)

return DefaultModelLoader(load_config)
16 changes: 16 additions & 0 deletions vllm/model_executor/model_loader/stream/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Copyright 2024 The Aibrix Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Storage backends the stream loader accepts as a file source.
SUPPORTED_STREAM_STORAGE = ("s3", "tos", "local")
# NOTE(review): presumably the subdirectory name used to cache download
# metadata/artifacts — confirm against the utils module that consumes it.
DOWNLOAD_CACHE_DIR = ".cache"
210 changes: 210 additions & 0 deletions vllm/model_executor/model_loader/stream/file.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
# Copyright 2024 The Aibrix Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from io import BytesIO
from pathlib import Path
from typing import Optional, Union

import boto3
import numpy as np
from boto3.s3.transfer import TransferConfig

from vllm.logger import init_logger

from .utils import (_create_s3_client, _parse_bucket_info_from_uri, meta_file,
need_to_download, read_to_bytes_io, save_meta_data)

logger = init_logger(__name__)


class LoadFile:
    """Abstract interface for reading a model file from a storage backend.

    Subclasses implement the actual byte access (local disk, S3, ...);
    the base class only records which kind of source the file comes from.
    """

    def __init__(self, file_source: str) -> None:
        # Identifier of the backing storage, e.g. "local", "s3" or "tos".
        self.file_source = file_source

    def load_whole_file(self, num_threads: int = 1):
        """Read the entire file; `num_threads` may parallelize the read."""
        raise NotImplementedError

    def load_to_bytes(self, offset: int, count: int) -> BytesIO:
        """Read `count` bytes starting at `offset` into a BytesIO stream."""
        raise NotImplementedError

    def load_to_buffer(self, offset: int, count: int) -> memoryview:
        """Read `count` bytes starting at `offset` as a buffer view."""
        raise NotImplementedError

    def download(self, target_dir):
        """Persist the file under `target_dir`."""
        raise NotImplementedError


class LocalFile(LoadFile):
    """A model file that already lives on the local filesystem."""

    def __init__(self, file: Union[str, Path]) -> None:
        # Fail fast on a bad path instead of at the first read.
        if not Path(file).exists():
            raise ValueError(f"file {file} not exist")
        self.file = file
        super().__init__(file_source="local")

    def load_whole_file(self, num_threads: int = 1):
        """Return the file's full contents as an owned bytes object.

        Local reads are single-threaded; any other `num_threads` value is
        ignored with a warning.
        """
        if num_threads != 1:
            logger.warning("num_threads %s is not supported for local file.",
                           num_threads)
        # Copy-on-write memmap, then materialize into bytes.
        mapped = np.memmap(self.file, dtype=np.uint8, mode="c")
        return mapped.tobytes()

    def load_to_bytes(self, offset: int, count: int):
        """Wrap the ranged buffer read in a BytesIO stream."""
        buffer = self.load_to_buffer(offset=offset, count=count)
        return BytesIO(buffer)

    def load_to_buffer(self, offset: int, count: int):
        """Memory-map `count` bytes starting at `offset` (read-only)."""
        return np.memmap(self.file,
                         dtype=np.uint8,
                         mode="r",
                         offset=offset,
                         shape=count)
)


class RemoteFile(LoadFile):
    """Base class for files fetched from remote storage (e.g. S3 or TOS)."""

    def __init__(self, file: str, file_source: str) -> None:
        self.file = file
        super().__init__(file_source=file_source)

    def load_to_buffer(self, offset: int, count: int):
        """Expose a ranged remote read as a memoryview over the bytes."""
        data = self.load_to_bytes(offset=offset, count=count)
        return data.getbuffer()

    def download_file(self, target_dir: str, num_threads: int = 1):
        """Fetch the remote file into `target_dir`; backend-specific."""
        raise NotImplementedError


class S3File(RemoteFile):
    """A single object stored in an S3-compatible bucket (AWS S3, TOS, ...).

    Supports whole-object download into memory, ranged reads via HTTP
    `Range` GETs, and cached download to a local directory with
    size/ETag-based change detection.
    """

    def __init__(
        self,
        scheme: str,
        bucket_name: str,
        bucket_path: str,
        s3_client: Optional[boto3.client] = None,
        s3_access_key_id: Optional[str] = None,
        s3_secret_access_key: Optional[str] = None,
        s3_region: Optional[str] = None,
        # NOTE: parameter name is a typo of "endpoint", kept as-is for
        # backward compatibility with existing callers.
        s3_endpinit: Optional[str] = None,
    ) -> None:
        self.bucket_name = bucket_name
        self.bucket_path = bucket_path
        if s3_client is None:
            try:
                s3_client = _create_s3_client(
                    ak=s3_access_key_id,
                    sk=s3_secret_access_key,
                    endpoint=s3_endpinit,
                    region=s3_region,
                )
            except Exception as e:
                raise ValueError(f"create s3 client failed for {e}.") from e
        self.s3_client = s3_client
        # Probe the object up front so a bad URI fails at construction time.
        try:
            self.s3_client.head_object(Bucket=bucket_name, Key=bucket_path)
        except Exception as e:
            raise ValueError(
                f"S3 bucket path {bucket_path} not exist for {e}.") from e

        file = scheme + "://" + bucket_name + "/" + bucket_path
        super().__init__(file=file, file_source=scheme)

    @classmethod
    def from_uri(cls, file_uri: str, **kwargs):
        """Build an S3File from a URI like ``s3://bucket/path/to/file``."""
        scheme, bucket_name, bucket_path = _parse_bucket_info_from_uri(
            file_uri)
        # BUGFIX: the constructed instance was previously discarded,
        # making this factory always return None.
        return cls(scheme, bucket_name, bucket_path, **kwargs)

    def load_whole_file(self, num_threads: int = 1):
        """Download the whole object into memory and return its buffer.

        `num_threads` bounds the concurrency of the multipart transfer.
        """
        config_kwargs = {
            "max_concurrency": num_threads,
            "use_threads": True,
        }
        config = TransferConfig(**config_kwargs)

        data = BytesIO()
        self.s3_client.download_fileobj(
            Bucket=self.bucket_name,
            Key=self.bucket_path,
            Fileobj=data,
            Config=config,
        )
        return data.getbuffer()

    def load_to_bytes(self, offset: int, count: int):
        """Fetch `count` bytes starting at `offset` via a ranged GET."""
        # HTTP Range is inclusive on both ends, hence the -1.
        range_header = f"bytes={offset}-{offset+count-1}"
        resp = self.s3_client.get_object(Bucket=self.bucket_name,
                                         Key=self.bucket_path,
                                         Range=range_header)
        return read_to_bytes_io(resp.get("Body"))

    def download_file(
        self,
        target_dir: str,
        num_threads: int = 1,
        force_download: bool = False,
    ):
        """Download the object into `target_dir`.

        Skips the transfer when the local copy's recorded size and ETag
        already match the remote object, unless `force_download` is set.
        """
        try:
            meta_data = self.s3_client.head_object(Bucket=self.bucket_name,
                                                   Key=self.bucket_path)
        except Exception as e:
            # BUGFIX: previously passed %-style lazy-logging args to
            # ValueError, which never formats them; use an f-string like
            # the identical check in __init__.
            raise ValueError(
                f"S3 bucket path {self.bucket_path} not exist for {e}."
            ) from e

        # ensure target dir exist
        target_path = Path(target_dir)
        target_path.mkdir(parents=True, exist_ok=True)

        _file_name = self.bucket_path.split("/")[-1]
        local_file = target_path.joinpath(_file_name).absolute()

        # check if file exist
        etag = meta_data.get("ETag", "")
        file_size = meta_data.get("ContentLength", 0)

        meta_data_file = meta_file(local_path=target_path,
                                   file_name=_file_name)
        if not need_to_download(local_file, meta_data_file, file_size, etag,
                                force_download):
            logger.info("file `%s` already exist.", self.bucket_path)
            return

        config_kwargs = {
            "max_concurrency": num_threads,
            "use_threads": True,
        }
        config = TransferConfig(**config_kwargs)
        self.s3_client.download_file(
            Bucket=self.bucket_name,
            Key=self.bucket_path,
            Filename=str(
                local_file
            ),  # S3 client does not support Path, convert it to str
            Config=config,
        )
        # Record the ETag so the next call can skip an unchanged object.
        save_meta_data(meta_data_file, etag)
        logger.info(
            "download file from `%s` to `%s` success.",
            self.bucket_path,
            target_dir,
        )
Loading
Loading