support ignore patterns in model loader (#6673)
simon-mo authored Jul 23, 2024
1 parent 22fa2e3 commit 3eda4ec
Showing 4 changed files with 51 additions and 10 deletions.
vllm/config.py (14 additions, 1 deletion)
@@ -599,12 +599,16 @@ class LoadConfig:
             mainly for profiling.
         "tensorizer" will use CoreWeave's tensorizer library for
             fast weight loading.
+        ignore_patterns: The list of patterns to ignore when loading the model.
+            Defaults to "original/**/*" to avoid repeated loading of llama's
+            checkpoints.
     """

     load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO
     download_dir: Optional[str] = None
     model_loader_extra_config: Optional[Union[str, dict]] = field(
         default_factory=dict)
+    ignore_patterns: Optional[Union[List[str], str]] = None

     def __post_init__(self):
         model_loader_extra_config = self.model_loader_extra_config or {}
@@ -613,6 +617,13 @@ def __post_init__(self):
             model_loader_extra_config)
         self._verify_load_format()

+        if self.ignore_patterns is not None and len(self.ignore_patterns) > 0:
+            logger.info(
+                "Ignoring the following patterns when downloading weights: %s",
+                self.ignore_patterns)
+        else:
+            self.ignore_patterns = ["original/**/*"]
+
     def _verify_load_format(self) -> None:
         if not isinstance(self.load_format, str):
             return
@@ -801,7 +812,9 @@ def __init__(self,
         # for higher throughput.
         self.max_num_batched_tokens = max(max_model_len, 2048)
         if enable_chunked_prefill:
-            logger.info("Chunked prefill is enabled (EXPERIMENTAL).")
+            logger.info(
+                "Chunked prefill is enabled with max_num_batched_tokens=%d.",
+                max_num_batched_tokens)

         self.max_num_seqs = max_num_seqs
         self.max_model_len = max_model_len
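A quick way to see the new default in action (a minimal sketch, assuming a vLLM build that includes this commit):

    from vllm.config import LoadConfig

    # No patterns given: __post_init__ falls back to the default above.
    cfg = LoadConfig()
    print(cfg.ignore_patterns)  # ['original/**/*']

    # Explicit patterns are kept as-is and logged at INFO level.
    cfg = LoadConfig(ignore_patterns=["*.bin"])
    print(cfg.ignore_patterns)  # ['*.bin']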
vllm/engine/arg_utils.py (10 additions, 0 deletions)
@@ -95,6 +95,7 @@ class EngineArgs:
     num_gpu_blocks_override: Optional[int] = None
     num_lookahead_slots: int = 0
     model_loader_extra_config: Optional[dict] = None
+    ignore_patterns: Optional[Union[str, List[str]]] = None
     preemption_mode: Optional[str] = None

     scheduler_delay_factor: float = 0.0
@@ -619,6 +620,14 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
                             'corresponding to the chosen load_format. '
                             'This should be a JSON string that will be '
                             'parsed into a dictionary.')
+        parser.add_argument(
+            '--ignore-patterns',
+            action="append",
+            type=str,
+            default=[],
+            help="The pattern(s) to ignore when loading the model. "
+            "Defaults to 'original/**/*' to avoid repeated loading of llama's "
+            "checkpoints.")
         parser.add_argument(
             '--preemption-mode',
             type=str,
@@ -824,6 +833,7 @@ def create_engine_config(self, ) -> EngineConfig:
             load_format=self.load_format,
             download_dir=self.download_dir,
             model_loader_extra_config=self.model_loader_extra_config,
+            ignore_patterns=self.ignore_patterns,
         )

         prompt_adapter_config = PromptAdapterConfig(
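Because the new flag uses action="append", it can be repeated on the command line and the values accumulate into a list. A standalone argparse sketch of those semantics (a throwaway parser for illustration, not vLLM's own FlexibleArgumentParser):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--ignore-patterns', action='append', type=str, default=[])

    # Repeating the flag appends each value in order.
    args = parser.parse_args(
        ['--ignore-patterns', 'original/**/*', '--ignore-patterns', '*.bin'])
    print(args.ignore_patterns)  # ['original/**/*', '*.bin']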
vllm/model_executor/model_loader/loader.py (21 additions, 8 deletions)
@@ -161,6 +161,7 @@ def _maybe_download_from_modelscope(
                 cache_dir=self.load_config.download_dir,
                 local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
                 revision=revision,
+                ignore_patterns=self.load_config.ignore_patterns,
             )
         else:
             model_path = model
@@ -196,9 +197,13 @@ def _prepare_weights(self, model_name_or_path: str,
             allow_patterns += ["*.pt"]

         if not is_local:
-            hf_folder = download_weights_from_hf(model_name_or_path,
-                                                 self.load_config.download_dir,
-                                                 allow_patterns, revision)
+            hf_folder = download_weights_from_hf(
+                model_name_or_path,
+                self.load_config.download_dir,
+                allow_patterns,
+                revision,
+                ignore_patterns=self.load_config.ignore_patterns,
+            )
         else:
             hf_folder = model_name_or_path

@@ -489,9 +494,13 @@ def _prepare_weights(self, model_name_or_path: str,
             return model_name_or_path
         else:
             allow_patterns = ["*.safetensors"]
-            return download_weights_from_hf(model_name_or_path,
-                                            self.load_config.download_dir,
-                                            allow_patterns, revision)
+            return download_weights_from_hf(
+                model_name_or_path,
+                self.load_config.download_dir,
+                allow_patterns,
+                revision,
+                ignore_patterns=self.load_config.ignore_patterns,
+            )

     def load_model(self, *, model_config: ModelConfig,
                    device_config: DeviceConfig,
@@ -663,8 +672,12 @@ def _get_weight_files(
             matching_files = fnmatch.filter(repo_files, pattern)
             if matching_files:
                 hf_folder = download_weights_from_hf(
-                    model_name_or_path, self.load_config.download_dir,
-                    [pattern], revision)
+                    model_name_or_path,
+                    self.load_config.download_dir,
+                    [pattern],
+                    revision,
+                    ignore_patterns=self.load_config.ignore_patterns,
+                )
                 return glob.glob(os.path.join(hf_folder, pattern)), pattern

         raise RuntimeError(
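In _get_weight_files above, the loader walks the allow-patterns in order and takes the first one that matches anything in the repo file listing, with fnmatch doing the glob-style matching. A small illustration with made-up file names:

    import fnmatch

    repo_files = ['model-00001-of-00002.safetensors', 'pytorch_model.bin',
                  'original/consolidated.00.pth']
    for pattern in ['*.safetensors', '*.bin', '*.pt']:
        matching = fnmatch.filter(repo_files, pattern)
        if matching:
            print(pattern, '->', matching)  # first matching pattern wins
            break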
vllm/model_executor/model_loader/weight_utils.py (6 additions, 1 deletion)
@@ -6,7 +6,7 @@
 import os
 import tempfile
 from collections import defaultdict
-from typing import Any, Generator, Iterable, List, Optional, Tuple
+from typing import Any, Generator, Iterable, List, Optional, Tuple, Union

 import filelock
 import huggingface_hub.constants
@@ -189,6 +189,7 @@ def download_weights_from_hf(
     cache_dir: Optional[str],
     allow_patterns: List[str],
     revision: Optional[str] = None,
+    ignore_patterns: Optional[Union[str, List[str]]] = None,
 ) -> str:
     """Download model weights from Hugging Face Hub.
@@ -200,6 +201,9 @@
             weight files. Files matched by any of the patterns will be
             downloaded.
         revision (Optional[str]): The revision of the model.
+        ignore_patterns (Optional[Union[str, List[str]]]): The patterns to
+            filter out the weight files. Files matched by any of the patterns
+            will be ignored.

     Returns:
         str: The path to the downloaded model weights.
@@ -223,6 +227,7 @@
     hf_folder = snapshot_download(
         model_name_or_path,
         allow_patterns=allow_patterns,
+        ignore_patterns=ignore_patterns,
         cache_dir=cache_dir,
         tqdm_class=DisabledTqdm,
         revision=revision,
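download_weights_from_hf forwards ignore_patterns straight to huggingface_hub's snapshot_download, which supports the parameter natively. A hedged sketch of the underlying call (the repo id is only an example):

    from huggingface_hub import snapshot_download

    # Fetch safetensors shards but skip the repo's original/ checkpoint
    # directory, the duplicated-download case this commit targets.
    path = snapshot_download(
        'meta-llama/Meta-Llama-3-8B',  # example repo id; may require access approval
        allow_patterns=['*.safetensors'],
        ignore_patterns=['original/**/*'],
    )
    print(path)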
