Skip to content

Commit

Permalink
Code review changes
Browse files Browse the repository at this point in the history
Signed-off-by: OmerD <[email protected]>
  • Loading branch information
omer-dayan committed Dec 19, 2024
1 parent fc19e86 commit 30af43e
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 10 deletions.
5 changes: 3 additions & 2 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ def __init__(self,
f"'Please instead use `--hf-overrides '{hf_override!r}'`")
warnings.warn(DeprecationWarning(msg), stacklevel=2)

self.pull_model_tokenizer_for_s3(model, tokenizer)
self.maybe_pull_model_tokenizer_for_s3(model, tokenizer)

# The tokenizer version is consistent with the model version by default.
if tokenizer_revision is None:
Expand Down Expand Up @@ -360,7 +360,8 @@ def __init__(self,
self._verify_cuda_graph()
self._verify_bnb_config()

def pull_model_tokenizer_for_s3(self, model: str, tokenizer: str) -> None:
def maybe_pull_model_tokenizer_for_s3(self, model: str,
tokenizer: str) -> None:
if is_s3(model) or is_s3(tokenizer):
try:
from vllm.transformers_utils.s3_utils import S3Model
Expand Down
2 changes: 1 addition & 1 deletion vllm/model_executor/model_loader/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -1301,7 +1301,7 @@ def _prepare_weights(self, model_name_or_path: str,
model_name_or_path, index_file, self.load_config.download_dir,
revision)

if len(hf_weights_files) == 0:
if not hf_weights_files:
raise RuntimeError(
f"Cannot find any safetensors model weights with "
f"`{model_name_or_path}`")
Expand Down
52 changes: 45 additions & 7 deletions vllm/transformers_utils/s3_utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from __future__ import annotations

import fnmatch
import os
import shutil
import signal
import tempfile
from pathlib import Path
from typing import Optional

import boto3

Expand All @@ -26,7 +25,16 @@ def _filter_ignore(paths: list[str], patterns: list[str]) -> list[str]:

def glob(s3=None,
path: str = "",
allow_pattern: list[str] | None = None) -> list[str]:
allow_pattern: Optional[list[str]] = None) -> list[str]:
"""
List the full file names under an S3 path, filtered by the allow patterns.
Args:
s3: S3 client to use.
path: The S3 path to list from.
allow_pattern: A list of patterns of which files to pull.
"""
if s3 is None:
s3 = boto3.client("s3")
bucket_name, _, paths = list_files(s3,
Expand All @@ -38,8 +46,19 @@ def glob(s3=None,
def list_files(
s3,
path: str,
allow_pattern: list[str] | None = None,
ignore_pattern: list[str] | None = None) -> tuple[str, str, list[str]]:
allow_pattern: Optional[list[str]] = None,
ignore_pattern: Optional[list[str]] = None
) -> tuple[str, str, list[str]]:
"""
List files from S3 path and filter by pattern.
Args:
s3: S3 client to use.
path: The S3 path to list from.
allow_pattern: A list of patterns matching the files to pull.
ignore_pattern: A list of patterns matching the files not to pull.
"""
parts = path.removeprefix('s3://').split('/')
prefix = '/'.join(parts[1:])
bucket_name = parts[0]
Expand All @@ -58,6 +77,16 @@ def list_files(


class S3Model:
"""
A class representing an S3 model mirrored into a temporary directory.
Attributes:
s3: S3 client.
dir: The temporary directory that was created.
Methods:
pull_files(): Pull model from S3 to the temporary directory.
"""

def __init__(self) -> None:
self.s3 = boto3.client('s3')
Expand All @@ -84,8 +113,17 @@ def new_handler(signum, frame):

def pull_files(self,
s3_model_path: str = "",
allow_pattern: list[str] | None = None,
ignore_pattern: list[str] | None = None) -> None:
allow_pattern: Optional[list[str]] = None,
ignore_pattern: Optional[list[str]] = None) -> None:
"""
Pull files from S3 storage into the temporary directory.
Args:
s3_model_path: The S3 path of the model.
allow_pattern: A list of patterns matching the files to pull.
ignore_pattern: A list of patterns matching the files not to pull.
"""
bucket_name, base_dir, files = list_files(self.s3, s3_model_path,
allow_pattern,
ignore_pattern)
Expand Down

0 comments on commit 30af43e

Please sign in to comment.