Skip to content

Commit

Permalink
[Core] Support offline use of local cache for models (#4374)
Browse files Browse the repository at this point in the history
Signed-off-by: Prashant Gupta <[email protected]>
Co-authored-by: Travis Johnson <[email protected]>
  • Loading branch information
prashantgupta24 and tjohnson31415 authored Apr 27, 2024
1 parent 81661da commit d6e520e
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 27 deletions.
30 changes: 29 additions & 1 deletion tests/model_executor/weight_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import os
import tempfile

import huggingface_hub.constants
import pytest
from huggingface_hub.utils import LocalEntryNotFoundError

from vllm.model_executor.model_loader.weight_utils import enable_hf_transfer
from vllm.model_executor.model_loader.weight_utils import (
download_weights_from_hf, enable_hf_transfer)


def test_hf_transfer_auto_activation():
Expand All @@ -22,5 +25,30 @@ def test_hf_transfer_auto_activation():
HF_TRANFER_ACTIVE)


def test_download_weights_from_hf():
    """Exercise offline-cache handling in download_weights_from_hf.

    Verifies, in order, that:
    1. offline mode raises LocalEntryNotFoundError when the model is
       not yet in the local cache,
    2. the model downloads normally while online,
    3. once cached, the model resolves offline without error.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        # Save the flag so global state is restored even on failure;
        # leaving HF_HUB_OFFLINE=True would leak into other tests.
        original_offline = huggingface_hub.constants.HF_HUB_OFFLINE
        try:
            # assert LocalEntryNotFoundError is thrown
            # if offline is set and model is not cached
            huggingface_hub.constants.HF_HUB_OFFLINE = True
            with pytest.raises(LocalEntryNotFoundError):
                download_weights_from_hf("facebook/opt-125m",
                                         allow_patterns=["*.safetensors", "*.bin"],
                                         cache_dir=tmpdir)

            # download the model
            huggingface_hub.constants.HF_HUB_OFFLINE = False
            download_weights_from_hf("facebook/opt-125m",
                                     allow_patterns=["*.safetensors", "*.bin"],
                                     cache_dir=tmpdir)

            # now it should work offline
            huggingface_hub.constants.HF_HUB_OFFLINE = True
            assert download_weights_from_hf(
                "facebook/opt-125m",
                allow_patterns=["*.safetensors", "*.bin"],
                cache_dir=tmpdir) is not None
        finally:
            huggingface_hub.constants.HF_HUB_OFFLINE = original_offline


# Allow running this test module directly, without a pytest runner.
if __name__ == "__main__":
    test_hf_transfer_auto_activation()
    test_download_weights_from_hf()
5 changes: 4 additions & 1 deletion vllm/model_executor/model_loader/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, Generator, List, Optional, Tuple, Type

import huggingface_hub
import torch
from torch import nn

Expand Down Expand Up @@ -131,7 +132,9 @@ def _maybe_download_from_modelscope(
model_path = snapshot_download(
model_id=model,
cache_dir=self.load_config.download_dir,
revision=revision)
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
revision=revision,
)
else:
model_path = model
return model_path
Expand Down
59 changes: 34 additions & 25 deletions vllm/model_executor/model_loader/weight_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,11 +127,14 @@ def get_quant_config(model_config: ModelConfig,
if not is_local:
# Download the config files.
with get_lock(model_name_or_path, load_config.download_dir):
hf_folder = snapshot_download(model_name_or_path,
revision=model_config.revision,
allow_patterns="*.json",
cache_dir=load_config.download_dir,
tqdm_class=DisabledTqdm)
hf_folder = snapshot_download(
model_name_or_path,
revision=model_config.revision,
allow_patterns="*.json",
cache_dir=load_config.download_dir,
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
tqdm_class=DisabledTqdm,
)
else:
hf_folder = model_name_or_path

Expand Down Expand Up @@ -161,12 +164,14 @@ def get_quant_config(model_config: ModelConfig,
return quant_cls.from_config(config)


def download_weights_from_hf(
    model_name_or_path: str,
    cache_dir: Optional[str],
    allow_patterns: List[str],
    revision: Optional[str] = None,
) -> str:
    """Download model weights from Hugging Face Hub.

    Respects huggingface_hub's offline mode (HF_HUB_OFFLINE): when set,
    no network calls are made and weights are resolved from the local
    cache only.

    Args:
        model_name_or_path (str): The model name or path.
        cache_dir (Optional[str]): The cache directory to store the model
            weights. If None, will use HF defaults.
        allow_patterns (List[str]): The allowed patterns for the
            weight files. Files matched by any of the patterns will be
            downloaded.
        revision (Optional[str]): The revision of the model.

    Returns:
        str: The path to the downloaded model weights.
    """
    if not huggingface_hub.constants.HF_HUB_OFFLINE:
        # Before we download we look at what is available:
        fs = HfFileSystem()
        file_list = fs.ls(model_name_or_path, detail=False, revision=revision)

        # depending on what is available we download different things;
        # keep only the first pattern that matches anything in the repo
        for pattern in allow_patterns:
            matching = fnmatch.filter(file_list, pattern)
            if len(matching) > 0:
                allow_patterns = [pattern]
                break

    logger.info("Using model weights format %s", allow_patterns)
    # Use file lock to prevent multiple processes from
    # downloading the same model weights at the same time.
    with get_lock(model_name_or_path, cache_dir):
        hf_folder = snapshot_download(
            model_name_or_path,
            allow_patterns=allow_patterns,
            cache_dir=cache_dir,
            tqdm_class=DisabledTqdm,
            revision=revision,
            # In offline mode, resolve strictly from the local cache.
            local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
        )
    return hf_folder


Expand Down
2 changes: 2 additions & 0 deletions vllm/transformers_utils/tokenizer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
from typing import Optional, Union

import huggingface_hub
from transformers import (AutoTokenizer, PreTrainedTokenizer,
PreTrainedTokenizerFast)

Expand Down Expand Up @@ -76,6 +77,7 @@ def get_tokenizer(
model_id=tokenizer_name,
cache_dir=download_dir,
revision=revision,
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
# Ignore weights - we only need the tokenizer.
ignore_file_pattern=["*.pt", "*.safetensors", "*.bin"])
tokenizer_name = tokenizer_path
Expand Down

0 comments on commit d6e520e

Please sign in to comment.