diff --git a/tests/test_config.py b/tests/test_config.py
index 66bdb883657c5..36c426d6c51f6 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -200,8 +200,10 @@ def test_rope_customization():
         trust_remote_code=False,
         dtype="float16",
         seed=0,
-        rope_scaling=TEST_ROPE_SCALING,
-        rope_theta=TEST_ROPE_THETA,
+        hf_overrides={
+            "rope_scaling": TEST_ROPE_SCALING,
+            "rope_theta": TEST_ROPE_THETA,
+        },
     )
     assert getattr(llama_model_config.hf_config, "rope_scaling",
                    None) == TEST_ROPE_SCALING
@@ -232,7 +234,9 @@ def test_rope_customization():
         trust_remote_code=False,
         dtype="float16",
         seed=0,
-        rope_scaling=TEST_ROPE_SCALING,
+        hf_overrides={
+            "rope_scaling": TEST_ROPE_SCALING,
+        },
     )
     assert getattr(longchat_model_config.hf_config, "rope_scaling",
                    None) == TEST_ROPE_SCALING
diff --git a/vllm/config.py b/vllm/config.py
index bed58fcecb5cb..b902499bf5bdc 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -1,5 +1,6 @@
 import enum
 import json
+import warnings
 from dataclasses import dataclass, field
 from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Final, List, Literal,
                     Mapping, Optional, Set, Tuple, Type, Union)
@@ -74,9 +75,6 @@ class ModelConfig:
         code_revision: The specific revision to use for the model code on
             Hugging Face Hub. It can be a branch name, a tag name, or a
             commit id. If unspecified, will use the default version.
-        rope_scaling: Dictionary containing the scaling configuration for the
-            RoPE embeddings. When using this flag, don't update
-            `max_position_embeddings` to the expected new maximum.
         tokenizer_revision: The specific tokenizer version to use. It can be a
             branch name, a tag name, or a commit id. If unspecified, will use
             the default version.
@@ -116,6 +114,7 @@ class ModelConfig:
             can not be gathered from the vllm arguments.
         config_format: The config format which shall be loaded.
             Defaults to 'auto' which defaults to 'hf'.
+        hf_overrides: Arguments to be forwarded to the HuggingFace config.
         mm_processor_kwargs: Arguments to be forwarded to the model's processor
             for multi-modal data, e.g., image processor.
         pooling_type: Used to configure the pooling method in the embedding
@@ -146,7 +145,7 @@ def __init__(
         allowed_local_media_path: str = "",
         revision: Optional[str] = None,
         code_revision: Optional[str] = None,
-        rope_scaling: Optional[dict] = None,
+        rope_scaling: Optional[Dict[str, Any]] = None,
         rope_theta: Optional[float] = None,
         tokenizer_revision: Optional[str] = None,
         max_model_len: Optional[int] = None,
@@ -164,6 +163,7 @@ def __init__(
         override_neuron_config: Optional[Dict[str, Any]] = None,
         config_format: ConfigFormat = ConfigFormat.AUTO,
         chat_template_text_format: str = "string",
+        hf_overrides: Optional[Dict[str, Any]] = None,
         mm_processor_kwargs: Optional[Dict[str, Any]] = None,
         pooling_type: Optional[str] = None,
         pooling_norm: Optional[bool] = None,
@@ -178,8 +178,22 @@ def __init__(
         self.seed = seed
         self.revision = revision
         self.code_revision = code_revision
-        self.rope_scaling = rope_scaling
-        self.rope_theta = rope_theta
+
+        if hf_overrides is None:
+            hf_overrides = {}
+        if rope_scaling is not None:
+            hf_override: Dict[str, Any] = {"rope_scaling": rope_scaling}
+            hf_overrides.update(hf_override)
+            msg = ("`--rope-scaling` will be removed in a future release. "
+                   f"Please instead use `--hf-overrides '{hf_override!r}'`")
+            warnings.warn(DeprecationWarning(msg), stacklevel=2)
+        if rope_theta is not None:
+            hf_override = {"rope_theta": rope_theta}
+            hf_overrides.update(hf_override)
+            msg = ("`--rope-theta` will be removed in a future release. "
+                   f"Please instead use `--hf-overrides '{hf_override!r}'`")
+            warnings.warn(DeprecationWarning(msg), stacklevel=2)
+
         # The tokenizer version is consistent with the model version by default.
         if tokenizer_revision is None:
             self.tokenizer_revision = revision
@@ -193,8 +207,8 @@ def __init__(
         self.disable_sliding_window = disable_sliding_window
         self.skip_tokenizer_init = skip_tokenizer_init
         self.hf_config = get_config(self.model, trust_remote_code, revision,
-                                    code_revision, rope_scaling, rope_theta,
-                                    config_format)
+                                    code_revision, config_format,
+                                    **hf_overrides)
         self.hf_text_config = get_hf_text_config(self.hf_config)
         self.encoder_config = self._get_encoder_config()
         self.hf_image_processor_config = get_hf_image_processor_config(
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 8c5b442e9f624..95d55e86e08e8 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -128,8 +128,9 @@ class EngineArgs:
     disable_log_stats: bool = False
     revision: Optional[str] = None
     code_revision: Optional[str] = None
-    rope_scaling: Optional[dict] = None
+    rope_scaling: Optional[Dict[str, Any]] = None
     rope_theta: Optional[float] = None
+    hf_overrides: Optional[Dict[str, Any]] = None
     tokenizer_revision: Optional[str] = None
     quantization: Optional[str] = None
     enforce_eager: Optional[bool] = None
@@ -140,8 +141,9 @@ class EngineArgs:
     # is intended for expert use only. The API may change without
     # notice.
     tokenizer_pool_type: Union[str, Type["BaseTokenizerGroup"]] = "ray"
-    tokenizer_pool_extra_config: Optional[dict] = None
+    tokenizer_pool_extra_config: Optional[Dict[str, Any]] = None
     limit_mm_per_prompt: Optional[Mapping[str, int]] = None
+    mm_processor_kwargs: Optional[Dict[str, Any]] = None
     enable_lora: bool = False
     max_loras: int = 1
     max_lora_rank: int = 16
@@ -187,7 +189,6 @@ class EngineArgs:
     collect_detailed_traces: Optional[str] = None
     disable_async_output_proc: bool = False
     override_neuron_config: Optional[Dict[str, Any]] = None
-    mm_processor_kwargs: Optional[Dict[str, Any]] = None
     scheduling_policy: Literal["fcfs", "priority"] = "fcfs"

     # Pooling configuration.
@@ -512,6 +513,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
                             help='RoPE theta. Use with `rope_scaling`. In '
                             'some cases, changing the RoPE theta improves the '
                             'performance of the scaled model.')
+        parser.add_argument('--hf-overrides',
+                            type=json.loads,
+                            default=EngineArgs.hf_overrides,
+                            help='Extra arguments for the HuggingFace config. '
+                            'This should be a JSON string that will be '
+                            'parsed into a dictionary.')
         parser.add_argument('--enforce-eager',
                             action='store_true',
                             help='Always use eager-mode PyTorch. If False, '
@@ -940,6 +947,7 @@ def create_model_config(self) -> ModelConfig:
             code_revision=self.code_revision,
             rope_scaling=self.rope_scaling,
             rope_theta=self.rope_theta,
+            hf_overrides=self.hf_overrides,
             tokenizer_revision=self.tokenizer_revision,
             max_model_len=self.max_model_len,
             quantization=self.quantization,
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 5d321fc98aeb6..d550b1d244af8 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -248,8 +248,7 @@ def __init__(
             "Initializing an LLM engine (v%s) with config: "
             "model=%r, speculative_config=%r, tokenizer=%r, "
             "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, "
-            "override_neuron_config=%s, "
-            "rope_scaling=%r, rope_theta=%r, tokenizer_revision=%s, "
+            "override_neuron_config=%s, tokenizer_revision=%s, "
             "trust_remote_code=%s, dtype=%s, max_seq_len=%d, "
             "download_dir=%r, load_format=%s, tensor_parallel_size=%d, "
             "pipeline_parallel_size=%d, "
@@ -271,8 +270,6 @@ def __init__(
             model_config.tokenizer_mode,
             model_config.revision,
             model_config.override_neuron_config,
-            model_config.rope_scaling,
-            model_config.rope_theta,
             model_config.tokenizer_revision,
             model_config.trust_remote_code,
             model_config.dtype,
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index d8b60a5e01471..f830839776364 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -98,7 +98,10 @@ class LLM:
             to eager mode. Additionally for encoder-decoder models, if the
             sequence length of the encoder input is larger than this, we fall
             back to the eager mode.
-        disable_custom_all_reduce: See ParallelConfig
+        disable_custom_all_reduce: See :class:`~vllm.config.ParallelConfig`
+        disable_async_output_proc: Disable async output processing.
+            This may result in lower performance.
+        hf_overrides: Arguments to be forwarded to the HuggingFace config.
         **kwargs: Arguments for :class:`~vllm.EngineArgs`. (See
             :ref:`engine_args`)

@@ -153,6 +156,7 @@ def __init__(
         max_seq_len_to_capture: int = 8192,
         disable_custom_all_reduce: bool = False,
         disable_async_output_proc: bool = False,
+        hf_overrides: Optional[dict] = None,
         mm_processor_kwargs: Optional[Dict[str, Any]] = None,
         # After positional args are removed, move this right below `model`
         task: TaskOption = "auto",
@@ -194,6 +198,7 @@ def __init__(
             max_seq_len_to_capture=max_seq_len_to_capture,
             disable_custom_all_reduce=disable_custom_all_reduce,
             disable_async_output_proc=disable_async_output_proc,
+            hf_overrides=hf_overrides,
             mm_processor_kwargs=mm_processor_kwargs,
             pooling_type=pooling_type,
             pooling_norm=pooling_norm,
diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py
index 6b38ee31c2657..14d9518364d26 100644
--- a/vllm/transformers_utils/config.py
+++ b/vllm/transformers_utils/config.py
@@ -146,9 +146,8 @@ def get_config(
     trust_remote_code: bool,
     revision: Optional[str] = None,
     code_revision: Optional[str] = None,
-    rope_scaling: Optional[dict] = None,
-    rope_theta: Optional[float] = None,
     config_format: ConfigFormat = ConfigFormat.AUTO,
+    token: Optional[str] = None,
     **kwargs,
 ) -> PretrainedConfig:
     # Separate model folder from file path for GGUF models
@@ -159,39 +158,43 @@ def get_config(
         model = Path(model).parent

     if config_format == ConfigFormat.AUTO:
-        if is_gguf or file_or_path_exists(model,
-                                          HF_CONFIG_NAME,
-                                          revision=revision,
-                                          token=kwargs.get("token")):
+        if is_gguf or file_or_path_exists(
+                model, HF_CONFIG_NAME, revision=revision, token=token):
             config_format = ConfigFormat.HF
         elif file_or_path_exists(model,
                                  MISTRAL_CONFIG_NAME,
                                  revision=revision,
-                                 token=kwargs.get("token")):
+                                 token=token):
             config_format = ConfigFormat.MISTRAL
         else:
             # If we're in offline mode and found no valid config format, then
             # raise an offline mode error to indicate to the user that they
             # don't have files cached and may need to go online.
             # This is conveniently triggered by calling file_exists().
-            file_exists(model,
-                        HF_CONFIG_NAME,
-                        revision=revision,
-                        token=kwargs.get("token"))
+            file_exists(model, HF_CONFIG_NAME, revision=revision, token=token)

             raise ValueError(f"No supported config format found in {model}")

     if config_format == ConfigFormat.HF:
         config_dict, _ = PretrainedConfig.get_config_dict(
-            model, revision=revision, code_revision=code_revision, **kwargs)
+            model,
+            revision=revision,
+            code_revision=code_revision,
+            token=token,
+            **kwargs,
+        )

         # Use custom model class if it's in our registry
         model_type = config_dict.get("model_type")
         if model_type in _CONFIG_REGISTRY:
             config_class = _CONFIG_REGISTRY[model_type]
-            config = config_class.from_pretrained(model,
-                                                  revision=revision,
-                                                  code_revision=code_revision)
+            config = config_class.from_pretrained(
+                model,
+                revision=revision,
+                code_revision=code_revision,
+                token=token,
+                **kwargs,
+            )
         else:
             try:
                 config = AutoConfig.from_pretrained(
@@ -199,6 +202,7 @@ def get_config(
                     trust_remote_code=trust_remote_code,
                     revision=revision,
                     code_revision=code_revision,
+                    token=token,
                     **kwargs,
                 )
             except ValueError as e:
@@ -216,7 +220,7 @@ def get_config(
                 raise e

     elif config_format == ConfigFormat.MISTRAL:
-        config = load_params_config(model, revision, token=kwargs.get("token"))
+        config = load_params_config(model, revision, token=token, **kwargs)
     else:
         raise ValueError(f"Unsupported config format: {config_format}")

@@ -228,19 +232,6 @@ def get_config(
         model_type = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[config.model_type]
         config.update({"architectures": [model_type]})

-    for key, value in [
-        ("rope_scaling", rope_scaling),
-        ("rope_theta", rope_theta),
-    ]:
-        if value is not None:
-            logger.info(
-                "Updating %s from %r to %r",
-                key,
-                getattr(config, key, None),
-                value,
-            )
-            config.update({key: value})
-
     patch_rope_scaling(config)

     return config
@@ -462,13 +453,15 @@ def _reduce_modelconfig(mc: ModelConfig):

 def load_params_config(model: Union[str, Path],
                        revision: Optional[str],
-                       token: Optional[str] = None) -> PretrainedConfig:
+                       token: Optional[str] = None,
+                       **kwargs) -> PretrainedConfig:
     # This function loads a params.json config which
     # should be used when loading models in mistral format

     config_file_name = "params.json"

     config_dict = get_hf_file_to_dict(config_file_name, model, revision, token)
+    assert isinstance(config_dict, dict)

     config_mapping = {
         "dim": "hidden_size",
@@ -512,6 +505,8 @@ def recurse_elems(elem: Any):
         config_dict["architectures"] = ["PixtralForConditionalGeneration"]
         config_dict["model_type"] = "pixtral"

+    config_dict.update(kwargs)
+
     config = recurse_elems(config_dict)

     return config
diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py
index 81dc01ae2d8e7..f805c5e69bc1c 100644
--- a/vllm/v1/engine/llm_engine.py
+++ b/vllm/v1/engine/llm_engine.py
@@ -74,8 +74,7 @@ def __init__(
             "Initializing an LLM engine (v%s) with config: "
             "model=%r, speculative_config=%r, tokenizer=%r, "
             "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, "
-            "override_neuron_config=%s, "
-            "rope_scaling=%r, rope_theta=%r, tokenizer_revision=%s, "
+            "override_neuron_config=%s, tokenizer_revision=%s, "
             "trust_remote_code=%s, dtype=%s, max_seq_len=%d, "
             "download_dir=%r, load_format=%s, tensor_parallel_size=%d, "
             "pipeline_parallel_size=%d, "
@@ -94,8 +93,6 @@ def __init__(
             model_config.tokenizer_mode,
            model_config.revision,
             model_config.override_neuron_config,
-            model_config.rope_scaling,
-            model_config.rope_theta,
             model_config.tokenizer_revision,
             model_config.trust_remote_code,
             model_config.dtype,
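
Usage note (not part of the patch): after this change, RoPE customization is passed through the new hf_overrides argument instead of the deprecated rope_scaling/rope_theta arguments. A minimal sketch of the new call pattern; the model name and override values below are illustrative placeholders, not taken from the diff:

from vllm import LLM

# Sketch only: keys in hf_overrides are forwarded verbatim into the
# HuggingFace config by get_config(), replacing the removed
# rope_scaling/rope_theta plumbing.
llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model name
    hf_overrides={
        "rope_scaling": {"rope_type": "dynamic", "factor": 2.0},
        "rope_theta": 1_000_000.0,
    },
)

On the command line, the same overrides would be supplied as a JSON string via the new flag, e.g. --hf-overrides '{"rope_theta": 1000000.0}', which json.loads turns into the dictionary that ModelConfig splats into the HF config.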