Skip to content

Commit

Permalink
[Frontend][Core] Override HF config.json via CLI (vllm-project#5836)
Browse files Browse the repository at this point in the history
Signed-off-by: DarkLight1337 <[email protected]>
Co-authored-by: DarkLight1337 <[email protected]>
  • Loading branch information
KrishnaM251 and DarkLight1337 authored Nov 9, 2024
1 parent d88bff1 commit b09895a
Show file tree
Hide file tree
Showing 7 changed files with 73 additions and 53 deletions.
10 changes: 7 additions & 3 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,10 @@ def test_rope_customization():
trust_remote_code=False,
dtype="float16",
seed=0,
rope_scaling=TEST_ROPE_SCALING,
rope_theta=TEST_ROPE_THETA,
hf_overrides={
"rope_scaling": TEST_ROPE_SCALING,
"rope_theta": TEST_ROPE_THETA,
},
)
assert getattr(llama_model_config.hf_config, "rope_scaling",
None) == TEST_ROPE_SCALING
Expand Down Expand Up @@ -232,7 +234,9 @@ def test_rope_customization():
trust_remote_code=False,
dtype="float16",
seed=0,
rope_scaling=TEST_ROPE_SCALING,
hf_overrides={
"rope_scaling": TEST_ROPE_SCALING,
},
)
assert getattr(longchat_model_config.hf_config, "rope_scaling",
None) == TEST_ROPE_SCALING
Expand Down
30 changes: 22 additions & 8 deletions vllm/config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import enum
import json
import warnings
from dataclasses import dataclass, field
from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Final, List, Literal,
Mapping, Optional, Set, Tuple, Type, Union)
Expand Down Expand Up @@ -74,9 +75,6 @@ class ModelConfig:
code_revision: The specific revision to use for the model code on
Hugging Face Hub. It can be a branch name, a tag name, or a
commit id. If unspecified, will use the default version.
rope_scaling: Dictionary containing the scaling configuration for the
RoPE embeddings. When using this flag, don't update
`max_position_embeddings` to the expected new maximum.
tokenizer_revision: The specific tokenizer version to use. It can be a
branch name, a tag name, or a commit id. If unspecified, will use
the default version.
Expand Down Expand Up @@ -116,6 +114,7 @@ class ModelConfig:
can not be gathered from the vllm arguments.
config_format: The config format which shall be loaded.
Defaults to 'auto' which defaults to 'hf'.
hf_overrides: Arguments to be forwarded to the HuggingFace config.
mm_processor_kwargs: Arguments to be forwarded to the model's processor
for multi-modal data, e.g., image processor.
pooling_type: Used to configure the pooling method in the embedding
Expand Down Expand Up @@ -146,7 +145,7 @@ def __init__(
allowed_local_media_path: str = "",
revision: Optional[str] = None,
code_revision: Optional[str] = None,
rope_scaling: Optional[dict] = None,
rope_scaling: Optional[Dict[str, Any]] = None,
rope_theta: Optional[float] = None,
tokenizer_revision: Optional[str] = None,
max_model_len: Optional[int] = None,
Expand All @@ -164,6 +163,7 @@ def __init__(
override_neuron_config: Optional[Dict[str, Any]] = None,
config_format: ConfigFormat = ConfigFormat.AUTO,
chat_template_text_format: str = "string",
hf_overrides: Optional[Dict[str, Any]] = None,
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
pooling_type: Optional[str] = None,
pooling_norm: Optional[bool] = None,
Expand All @@ -178,8 +178,22 @@ def __init__(
self.seed = seed
self.revision = revision
self.code_revision = code_revision
self.rope_scaling = rope_scaling
self.rope_theta = rope_theta

if hf_overrides is None:
hf_overrides = {}
if rope_scaling is not None:
hf_override: Dict[str, Any] = {"rope_scaling": rope_scaling}
hf_overrides.update(hf_override)
msg = ("`--rope-scaling` will be removed in a future release. "
f"'Please instead use `--hf-overrides '{hf_override!r}'`")
warnings.warn(DeprecationWarning(msg), stacklevel=2)
if rope_theta is not None:
hf_override = {"rope_theta": rope_theta}
hf_overrides.update(hf_override)
msg = ("`--rope-theta` will be removed in a future release. "
f"'Please instead use `--hf-overrides '{hf_override!r}'`")
warnings.warn(DeprecationWarning(msg), stacklevel=2)

# The tokenizer version is consistent with the model version by default.
if tokenizer_revision is None:
self.tokenizer_revision = revision
Expand All @@ -193,8 +207,8 @@ def __init__(
self.disable_sliding_window = disable_sliding_window
self.skip_tokenizer_init = skip_tokenizer_init
self.hf_config = get_config(self.model, trust_remote_code, revision,
code_revision, rope_scaling, rope_theta,
config_format)
code_revision, config_format,
**hf_overrides)
self.hf_text_config = get_hf_text_config(self.hf_config)
self.encoder_config = self._get_encoder_config()
self.hf_image_processor_config = get_hf_image_processor_config(
Expand Down
14 changes: 11 additions & 3 deletions vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,9 @@ class EngineArgs:
disable_log_stats: bool = False
revision: Optional[str] = None
code_revision: Optional[str] = None
rope_scaling: Optional[dict] = None
rope_scaling: Optional[Dict[str, Any]] = None
rope_theta: Optional[float] = None
hf_overrides: Optional[Dict[str, Any]] = None
tokenizer_revision: Optional[str] = None
quantization: Optional[str] = None
enforce_eager: Optional[bool] = None
Expand All @@ -140,8 +141,9 @@ class EngineArgs:
# is intended for expert use only. The API may change without
# notice.
tokenizer_pool_type: Union[str, Type["BaseTokenizerGroup"]] = "ray"
tokenizer_pool_extra_config: Optional[dict] = None
tokenizer_pool_extra_config: Optional[Dict[str, Any]] = None
limit_mm_per_prompt: Optional[Mapping[str, int]] = None
mm_processor_kwargs: Optional[Dict[str, Any]] = None
enable_lora: bool = False
max_loras: int = 1
max_lora_rank: int = 16
Expand Down Expand Up @@ -187,7 +189,6 @@ class EngineArgs:
collect_detailed_traces: Optional[str] = None
disable_async_output_proc: bool = False
override_neuron_config: Optional[Dict[str, Any]] = None
mm_processor_kwargs: Optional[Dict[str, Any]] = None
scheduling_policy: Literal["fcfs", "priority"] = "fcfs"

# Pooling configuration.
Expand Down Expand Up @@ -512,6 +513,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
help='RoPE theta. Use with `rope_scaling`. In '
'some cases, changing the RoPE theta improves the '
'performance of the scaled model.')
parser.add_argument('--hf-overrides',
type=json.loads,
default=EngineArgs.hf_overrides,
help='Extra arguments for the HuggingFace config.'
'This should be a JSON string that will be '
'parsed into a dictionary.')
parser.add_argument('--enforce-eager',
action='store_true',
help='Always use eager-mode PyTorch. If False, '
Expand Down Expand Up @@ -940,6 +947,7 @@ def create_model_config(self) -> ModelConfig:
code_revision=self.code_revision,
rope_scaling=self.rope_scaling,
rope_theta=self.rope_theta,
hf_overrides=self.hf_overrides,
tokenizer_revision=self.tokenizer_revision,
max_model_len=self.max_model_len,
quantization=self.quantization,
Expand Down
5 changes: 1 addition & 4 deletions vllm/engine/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,8 +248,7 @@ def __init__(
"Initializing an LLM engine (v%s) with config: "
"model=%r, speculative_config=%r, tokenizer=%r, "
"skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, "
"override_neuron_config=%s, "
"rope_scaling=%r, rope_theta=%r, tokenizer_revision=%s, "
"override_neuron_config=%s, tokenizer_revision=%s, "
"trust_remote_code=%s, dtype=%s, max_seq_len=%d, "
"download_dir=%r, load_format=%s, tensor_parallel_size=%d, "
"pipeline_parallel_size=%d, "
Expand All @@ -271,8 +270,6 @@ def __init__(
model_config.tokenizer_mode,
model_config.revision,
model_config.override_neuron_config,
model_config.rope_scaling,
model_config.rope_theta,
model_config.tokenizer_revision,
model_config.trust_remote_code,
model_config.dtype,
Expand Down
7 changes: 6 additions & 1 deletion vllm/entrypoints/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,10 @@ class LLM:
to eager mode. Additionally for encoder-decoder models, if the
sequence length of the encoder input is larger than this, we fall
back to the eager mode.
disable_custom_all_reduce: See ParallelConfig
disable_custom_all_reduce: See :class:`~vllm.config.ParallelConfig`
disable_async_output_proc: Disable async output processing.
This may result in lower performance.
hf_overrides: Arguments to be forwarded to the HuggingFace config.
**kwargs: Arguments for :class:`~vllm.EngineArgs`. (See
:ref:`engine_args`)
Expand Down Expand Up @@ -153,6 +156,7 @@ def __init__(
max_seq_len_to_capture: int = 8192,
disable_custom_all_reduce: bool = False,
disable_async_output_proc: bool = False,
hf_overrides: Optional[dict] = None,
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
# After positional args are removed, move this right below `model`
task: TaskOption = "auto",
Expand Down Expand Up @@ -194,6 +198,7 @@ def __init__(
max_seq_len_to_capture=max_seq_len_to_capture,
disable_custom_all_reduce=disable_custom_all_reduce,
disable_async_output_proc=disable_async_output_proc,
hf_overrides=hf_overrides,
mm_processor_kwargs=mm_processor_kwargs,
pooling_type=pooling_type,
pooling_norm=pooling_norm,
Expand Down
55 changes: 25 additions & 30 deletions vllm/transformers_utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,9 +146,8 @@ def get_config(
trust_remote_code: bool,
revision: Optional[str] = None,
code_revision: Optional[str] = None,
rope_scaling: Optional[dict] = None,
rope_theta: Optional[float] = None,
config_format: ConfigFormat = ConfigFormat.AUTO,
token: Optional[str] = None,
**kwargs,
) -> PretrainedConfig:
# Separate model folder from file path for GGUF models
Expand All @@ -159,46 +158,51 @@ def get_config(
model = Path(model).parent

if config_format == ConfigFormat.AUTO:
if is_gguf or file_or_path_exists(model,
HF_CONFIG_NAME,
revision=revision,
token=kwargs.get("token")):
if is_gguf or file_or_path_exists(
model, HF_CONFIG_NAME, revision=revision, token=token):
config_format = ConfigFormat.HF
elif file_or_path_exists(model,
MISTRAL_CONFIG_NAME,
revision=revision,
token=kwargs.get("token")):
token=token):
config_format = ConfigFormat.MISTRAL
else:
# If we're in offline mode and found no valid config format, then
# raise an offline mode error to indicate to the user that they
# don't have files cached and may need to go online.
# This is conveniently triggered by calling file_exists().
file_exists(model,
HF_CONFIG_NAME,
revision=revision,
token=kwargs.get("token"))
file_exists(model, HF_CONFIG_NAME, revision=revision, token=token)

raise ValueError(f"No supported config format found in {model}")

if config_format == ConfigFormat.HF:
config_dict, _ = PretrainedConfig.get_config_dict(
model, revision=revision, code_revision=code_revision, **kwargs)
model,
revision=revision,
code_revision=code_revision,
token=token,
**kwargs,
)

# Use custom model class if it's in our registry
model_type = config_dict.get("model_type")
if model_type in _CONFIG_REGISTRY:
config_class = _CONFIG_REGISTRY[model_type]
config = config_class.from_pretrained(model,
revision=revision,
code_revision=code_revision)
config = config_class.from_pretrained(
model,
revision=revision,
code_revision=code_revision,
token=token,
**kwargs,
)
else:
try:
config = AutoConfig.from_pretrained(
model,
trust_remote_code=trust_remote_code,
revision=revision,
code_revision=code_revision,
token=token,
**kwargs,
)
except ValueError as e:
Expand All @@ -216,7 +220,7 @@ def get_config(
raise e

elif config_format == ConfigFormat.MISTRAL:
config = load_params_config(model, revision, token=kwargs.get("token"))
config = load_params_config(model, revision, token=token, **kwargs)
else:
raise ValueError(f"Unsupported config format: {config_format}")

Expand All @@ -228,19 +232,6 @@ def get_config(
model_type = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[config.model_type]
config.update({"architectures": [model_type]})

for key, value in [
("rope_scaling", rope_scaling),
("rope_theta", rope_theta),
]:
if value is not None:
logger.info(
"Updating %s from %r to %r",
key,
getattr(config, key, None),
value,
)
config.update({key: value})

patch_rope_scaling(config)

return config
Expand Down Expand Up @@ -462,13 +453,15 @@ def _reduce_modelconfig(mc: ModelConfig):

def load_params_config(model: Union[str, Path],
revision: Optional[str],
token: Optional[str] = None) -> PretrainedConfig:
token: Optional[str] = None,
**kwargs) -> PretrainedConfig:
# This function loads a params.json config which
# should be used when loading models in mistral format

config_file_name = "params.json"

config_dict = get_hf_file_to_dict(config_file_name, model, revision, token)
assert isinstance(config_dict, dict)

config_mapping = {
"dim": "hidden_size",
Expand Down Expand Up @@ -512,6 +505,8 @@ def recurse_elems(elem: Any):
config_dict["architectures"] = ["PixtralForConditionalGeneration"]
config_dict["model_type"] = "pixtral"

config_dict.update(kwargs)

config = recurse_elems(config_dict)
return config

Expand Down
5 changes: 1 addition & 4 deletions vllm/v1/engine/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,7 @@ def __init__(
"Initializing an LLM engine (v%s) with config: "
"model=%r, speculative_config=%r, tokenizer=%r, "
"skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, "
"override_neuron_config=%s, "
"rope_scaling=%r, rope_theta=%r, tokenizer_revision=%s, "
"override_neuron_config=%s, tokenizer_revision=%s, "
"trust_remote_code=%s, dtype=%s, max_seq_len=%d, "
"download_dir=%r, load_format=%s, tensor_parallel_size=%d, "
"pipeline_parallel_size=%d, "
Expand All @@ -94,8 +93,6 @@ def __init__(
model_config.tokenizer_mode,
model_config.revision,
model_config.override_neuron_config,
model_config.rope_scaling,
model_config.rope_theta,
model_config.tokenizer_revision,
model_config.trust_remote_code,
model_config.dtype,
Expand Down

0 comments on commit b09895a

Please sign in to comment.