
Commit

Extra check for whether the file exists
Signed-off-by: Flavia Beo <[email protected]>
flaviabeo committed Oct 25, 2024
1 parent 1bcd3e8 commit 69222e4
Showing 2 changed files with 42 additions and 23 deletions.
3 changes: 1 addition & 2 deletions vllm/engine/llm_engine.py
@@ -280,8 +280,7 @@ def __init__(
             model_config.pooling_config.normalize
             if model_config.pooling_config is not None else None,
             model_config.chat_template_text_format,
-            model_config.mm_processor_kwargs
-        )
+            model_config.mm_processor_kwargs)
         # TODO(woosuk): Print more configs in debug mode.
         self.model_config = model_config
         self.cache_config = cache_config
62 changes: 41 additions & 21 deletions vllm/transformers_utils/config.py
@@ -205,7 +205,7 @@ def get_config(
                 raise e

     elif config_format == ConfigFormat.MISTRAL:
-        config = load_params_config(model, revision)
+        config = load_params_config(model, revision, token=kwargs.get("token"))
     else:
         raise ValueError(f"Unsupported config format: {config_format}")

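The hunk above threads an optional Hugging Face token from get_config's keyword arguments into load_params_config. As an illustrative sketch only (the full get_config signature is not shown in this diff, and the model name below is just a placeholder), a caller could pass the token like this:

    # Illustrative sketch, not part of the commit: assumes get_config accepts
    # extra **kwargs and that ConfigFormat is exported by the same module.
    from vllm.transformers_utils.config import ConfigFormat, get_config

    config = get_config(
        "mistralai/Mistral-7B-Instruct-v0.3",  # placeholder model name
        trust_remote_code=False,
        revision="main",
        config_format=ConfigFormat.MISTRAL,
        token="hf_...",  # read inside get_config via kwargs.get("token")
    )
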
@@ -235,7 +235,10 @@ def get_config(
     return config


-def get_hf_file_to_dict(file_name, model, revision):
+def get_hf_file_to_dict(file_name,
+                        model,
+                        revision,
+                        token: Optional[str] = None):
     """
     Downloads a file from the Hugging Face Hub and returns
     its contents as a dictionary.
@@ -244,29 +247,39 @@ def get_hf_file_to_dict(file_name, model, revision):
     - file_name (str): The name of the file to download.
     - model (str): The name of the model on the Hugging Face Hub.
     - revision (str): The specific version of the model.
+    - token (str): The Hugging Face authentication token.
     Returns:
     - config_dict (dict): A dictionary containing
     the contents of the downloaded file.
     """
     file_path = Path(model) / file_name

-    if not file_path.is_file():
-        try:
-            hf_hub_file = hf_hub_download(model, file_name, revision=revision)
-        except (RepositoryNotFoundError, RevisionNotFoundError,
-                EntryNotFoundError, LocalEntryNotFoundError) as e:
-            logger.debug("File or repository not found in hf_hub_download", e)
-            return None
-        file_path = Path(hf_hub_file)
-
-    with open(file_path, "r") as file:
-        config_dict = json.load(file)
-
-    return config_dict
+    if file_or_path_exists(model=model,
+                           config_name=file_name,
+                           revision=revision,
+                           token=token):
+
+        if not file_path.is_file():
+            try:
+                hf_hub_file = hf_hub_download(model,
+                                              file_name,
+                                              revision=revision)
+            except (RepositoryNotFoundError, RevisionNotFoundError,
+                    EntryNotFoundError, LocalEntryNotFoundError) as e:
+                logger.debug("File or repository not found in hf_hub_download",
+                             e)
+                return None
+            file_path = Path(hf_hub_file)
+
+        with open(file_path, "r") as file:
+            config_dict = json.load(file)
+
+        return config_dict
+    return None


-def get_pooling_config(model, revision='main'):
+def get_pooling_config(model, revision='main', token: Optional[str] = None):
     """
     This function gets the pooling and normalize
     config from the model - only applies to
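The rewritten get_hf_file_to_dict above only downloads a file after an explicit existence check. The helper it calls, file_or_path_exists, is not part of this diff; a minimal hypothetical sketch of such a check (built on huggingface_hub.file_exists, which may differ from the actual vLLM helper) could look like:

    # Hypothetical sketch of the existence check assumed by the new code above;
    # the real file_or_path_exists helper in vLLM may be implemented differently.
    from pathlib import Path
    from typing import Optional

    from huggingface_hub import file_exists


    def file_or_path_exists(model: str, config_name: str,
                            revision: Optional[str],
                            token: Optional[str]) -> bool:
        # A local checkout of the model directory takes precedence over the Hub.
        if (Path(model) / config_name).is_file():
            return True
        # Otherwise ask the Hub whether the file exists at that revision,
        # passing the token so gated or private repositories can be checked.
        return file_exists(model, config_name, revision=revision, token=token)
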
@@ -283,7 +296,8 @@ def get_pooling_config(model, revision='main'):
"""

modules_file_name = "modules.json"
modules_dict = get_hf_file_to_dict(modules_file_name, model, revision)
modules_dict = get_hf_file_to_dict(modules_file_name, model, revision,
token)

if modules_dict is None:
return None
@@ -299,7 +313,8 @@ def get_pooling_config(model, revision='main'):
     if pooling:

         pooling_file_name = "{}/config.json".format(pooling["path"])
-        pooling_dict = get_hf_file_to_dict(pooling_file_name, model, revision)
+        pooling_dict = get_hf_file_to_dict(pooling_file_name, model, revision,
+                                           token)
         pooling_type_name = next(
             (item for item, val in pooling_dict.items() if val is True), None)

@@ -308,7 +323,9 @@ def get_pooling_config(model, revision='main'):
     return None


-def get_sentence_transformer_tokenizer_config(model, revision='main'):
+def get_sentence_transformer_tokenizer_config(model,
+                                              revision='main',
+                                              token: Optional[str] = None):
     """
     Returns the tokenization configuration dictionary for a
     given Sentence Transformer BERT model.
@@ -318,6 +335,7 @@ def get_sentence_transformer_tokenizer_config(model, revision='main'):
     BERT model.
     - revision (str, optional): The revision of the m
     odel to use. Defaults to 'main'.
+    - token (str): A Hugging Face access token.
     Returns:
     - dict: A dictionary containing the configuration parameters
@@ -332,7 +350,7 @@ def get_sentence_transformer_tokenizer_config(model, revision='main'):
"sentence_xlm-roberta_config.json",
"sentence_xlnet_config.json",
]:
bert_dict = get_hf_file_to_dict(config_name, model, revision)
bert_dict = get_hf_file_to_dict(config_name, model, revision, token)
if bert_dict:
break

@@ -406,13 +424,15 @@ def _reduce_modelconfig(mc: ModelConfig):
             exc_info=e)


-def load_params_config(model, revision) -> PretrainedConfig:
+def load_params_config(model,
+                       revision,
+                       token: Optional[str] = None) -> PretrainedConfig:
     # This function loads a params.json config which
     # should be used when loading models in mistral format

     config_file_name = "params.json"

-    config_dict = get_hf_file_to_dict(config_file_name, model, revision)
+    config_dict = get_hf_file_to_dict(config_file_name, model, revision, token)

     config_mapping = {
         "dim": "hidden_size",
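Taken together, the new token parameter lets the pooling and sentence-transformers tokenizer config lookups authenticate against gated or private repositories. A hedged usage sketch (the repository name and token below are placeholders):

    # Illustrative only: the model name and token below are placeholders.
    from vllm.transformers_utils.config import (
        get_pooling_config, get_sentence_transformer_tokenizer_config)

    hf_token = "hf_..."  # a real Hugging Face access token
    pooling = get_pooling_config("my-org/private-embedding-model",
                                 token=hf_token)
    tokenizer_cfg = get_sentence_transformer_tokenizer_config(
        "my-org/private-embedding-model", revision="main", token=hf_token)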
