diff --git a/tests/lora/test_long_context.py b/tests/lora/test_long_context.py index c8edb02a88d4b..eada902c891f7 100644 --- a/tests/lora/test_long_context.py +++ b/tests/lora/test_long_context.py @@ -138,13 +138,7 @@ def test_rotary_emb_replaced(dist_init): enable_lora=True) engine_config = engine_args.create_engine_config() model_runner = ModelRunner( - model_config=engine_config.model_config, - parallel_config=engine_config.parallel_config, - scheduler_config=engine_config.scheduler_config, - device_config=engine_config.device_config, - cache_config=engine_config.cache_config, - load_config=engine_config.load_config, - lora_config=engine_config.lora_config, + vllm_config=engine_config, is_driver_worker=True, ) model_runner.load_model() diff --git a/tests/lora/test_worker.py b/tests/lora/test_worker.py index 2f7ac85507425..9d814f657ac43 100644 --- a/tests/lora/test_worker.py +++ b/tests/lora/test_worker.py @@ -4,7 +4,8 @@ from unittest.mock import patch from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, ParallelConfig, SchedulerConfig) + ModelConfig, ParallelConfig, SchedulerConfig, + VllmConfig) from vllm.lora.models import LoRAMapping from vllm.lora.request import LoRARequest from vllm.worker.worker import Worker @@ -12,7 +13,7 @@ @patch.dict(os.environ, {"RANK": "0"}) def test_worker_apply_lora(sql_lora_files): - worker = Worker( + vllm_config = VllmConfig( model_config=ModelConfig( "meta-llama/Llama-2-7b-hf", task="auto", @@ -34,10 +35,13 @@ def test_worker_apply_lora(sql_lora_files): gpu_memory_utilization=1., swap_space=0, cache_dtype="auto"), - local_rank=0, - rank=0, lora_config=LoRAConfig(max_lora_rank=8, max_cpu_loras=32, max_loras=32), + ) + worker = Worker( + vllm_config=vllm_config, + local_rank=0, + rank=0, distributed_init_method=f"file://{tempfile.mkstemp()[1]}", ) worker.init_device() diff --git a/tests/spec_decode/utils.py b/tests/spec_decode/utils.py index f683942a5854b..6cf0cfb09b8fa 100644 --- a/tests/spec_decode/utils.py +++ b/tests/spec_decode/utils.py @@ -81,12 +81,7 @@ def create_worker(cls: Callable[..., T], get_ip(), get_open_port()) worker = cls( - model_config=engine_config.model_config, - parallel_config=engine_config.parallel_config, - scheduler_config=engine_config.scheduler_config, - device_config=engine_config.device_config, - cache_config=engine_config.cache_config, - load_config=engine_config.load_config, + vllm_config=engine_config, local_rank=0, rank=0, distributed_init_method=distributed_init_method, diff --git a/tests/worker/test_encoder_decoder_model_runner.py b/tests/worker/test_encoder_decoder_model_runner.py index e75884a7395e2..9e166ae64dbfb 100644 --- a/tests/worker/test_encoder_decoder_model_runner.py +++ b/tests/worker/test_encoder_decoder_model_runner.py @@ -19,14 +19,7 @@ def _create_model_runner(model: str, *args, engine_args = EngineArgs(model, *args, **kwargs) engine_config = engine_args.create_engine_config() model_runner = EncoderDecoderModelRunner( - model_config=engine_config.model_config, - parallel_config=engine_config.parallel_config, - scheduler_config=engine_config.scheduler_config, - device_config=engine_config.device_config, - cache_config=engine_config.cache_config, - load_config=engine_config.load_config, - lora_config=engine_config.lora_config, - prompt_adapter_config=engine_config.prompt_adapter_config, + vllm_config=engine_config, is_driver_worker=True, ) return model_runner diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index 
fe97199bac62d..433a9b30ba57a 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -16,15 +16,7 @@ def _create_model_runner(model: str, *args, **kwargs) -> ModelRunner: engine_args = EngineArgs(model, *args, **kwargs) engine_config = engine_args.create_engine_config() model_runner = ModelRunner( - model_config=engine_config.model_config, - parallel_config=engine_config.parallel_config, - scheduler_config=engine_config.scheduler_config, - device_config=engine_config.device_config, - cache_config=engine_config.cache_config, - load_config=engine_config.load_config, - lora_config=engine_config.lora_config, - prompt_adapter_config=engine_config.prompt_adapter_config, - observability_config=engine_config.observability_config, + vllm_config=engine_config, is_driver_worker=True, ) return model_runner diff --git a/tests/worker/test_profile.py b/tests/worker/test_profile.py index acd2ed6836365..194ea2aa506f4 100644 --- a/tests/worker/test_profile.py +++ b/tests/worker/test_profile.py @@ -24,12 +24,7 @@ def test_gpu_memory_profiling(): distributed_init_method = get_distributed_init_method( get_ip(), get_open_port()) worker = Worker( - model_config=engine_config.model_config, - parallel_config=engine_config.parallel_config, - scheduler_config=engine_config.scheduler_config, - device_config=engine_config.device_config, - cache_config=engine_config.cache_config, - load_config=engine_config.load_config, + vllm_config=engine_config, local_rank=0, rank=0, distributed_init_method=distributed_init_method, diff --git a/tests/worker/test_swap.py b/tests/worker/test_swap.py index 7aa439ba0a154..acede959f59f8 100644 --- a/tests/worker/test_swap.py +++ b/tests/worker/test_swap.py @@ -19,12 +19,7 @@ def test_swap() -> None: distributed_init_method = get_distributed_init_method( get_ip(), get_open_port()) worker = Worker( - model_config=engine_config.model_config, - parallel_config=engine_config.parallel_config, - scheduler_config=engine_config.scheduler_config, - device_config=engine_config.device_config, - cache_config=engine_config.cache_config, - load_config=engine_config.load_config, + vllm_config=engine_config, local_rank=0, rank=0, distributed_init_method=distributed_init_method, diff --git a/vllm/config.py b/vllm/config.py index fe2fd3b1d59d3..9a36a342dac2f 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1,6 +1,6 @@ import enum import json -from dataclasses import dataclass, field, fields +from dataclasses import dataclass, field from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Final, List, Literal, Mapping, Optional, Set, Tuple, Type, Union) @@ -1954,9 +1954,9 @@ def __post_init__(self): f"installed. Original error:\n{otel_import_error_traceback}") -@dataclass(frozen=True) -class EngineConfig: - """Dataclass which contains all engine-related configuration. This +@dataclass +class VllmConfig: + """Dataclass which contains all vllm-related configuration. This simplifies passing around the distinct configurations in the codebase. 
""" @@ -1966,11 +1966,11 @@ class EngineConfig: scheduler_config: SchedulerConfig device_config: DeviceConfig load_config: LoadConfig - lora_config: Optional[LoRAConfig] - speculative_config: Optional[SpeculativeConfig] - decoding_config: Optional[DecodingConfig] - observability_config: Optional[ObservabilityConfig] - prompt_adapter_config: Optional[PromptAdapterConfig] + lora_config: Optional[LoRAConfig] = None + speculative_config: Optional[SpeculativeConfig] = None + decoding_config: Optional[DecodingConfig] = None + observability_config: Optional[ObservabilityConfig] = None + prompt_adapter_config: Optional[PromptAdapterConfig] = None def __post_init__(self): """Verify configs are valid & consistent with each other. @@ -1988,9 +1988,3 @@ def __post_init__(self): if self.prompt_adapter_config: self.prompt_adapter_config.verify_with_model_config( self.model_config) - - def to_dict(self): - """Return the configs as a dictionary, for use in **kwargs. - """ - return dict( - (field.name, getattr(self, field.name)) for field in fields(self)) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index a67f489df3868..685877299a61f 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -9,10 +9,11 @@ import vllm.envs as envs from vllm.config import (CacheConfig, ConfigFormat, DecodingConfig, - DeviceConfig, EngineConfig, LoadConfig, LoadFormat, - LoRAConfig, ModelConfig, ObservabilityConfig, - ParallelConfig, PromptAdapterConfig, SchedulerConfig, - SpeculativeConfig, TaskOption, TokenizerPoolConfig) + DeviceConfig, LoadConfig, LoadFormat, LoRAConfig, + ModelConfig, ObservabilityConfig, ParallelConfig, + PromptAdapterConfig, SchedulerConfig, + SpeculativeConfig, TaskOption, TokenizerPoolConfig, + VllmConfig) from vllm.executor.executor_base import ExecutorBase from vllm.logger import init_logger from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS @@ -955,7 +956,7 @@ def create_load_config(self) -> LoadConfig: ignore_patterns=self.ignore_patterns, ) - def create_engine_config(self) -> EngineConfig: + def create_engine_config(self) -> VllmConfig: # gguf file needs a specific model loader and doesn't use hf_repo if check_gguf_file(self.model): self.quantization = self.load_format = "gguf" @@ -1167,7 +1168,7 @@ def create_engine_config(self) -> EngineConfig: or "all" in detailed_trace_modules, ) - return EngineConfig( + return VllmConfig( model_config=model_config, cache_config=cache_config, parallel_config=parallel_config, diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 6aeaf484a22b4..b0fdc67776bbd 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -7,8 +7,8 @@ from weakref import ReferenceType import vllm.envs as envs -from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig) +from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, + ParallelConfig, SchedulerConfig, VllmConfig) from vllm.core.scheduler import SchedulerOutputs from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_timeout import asyncio_timeout @@ -604,7 +604,7 @@ def __del__(self): @classmethod def _get_executor_cls( - cls, engine_config: EngineConfig) -> Type[ExecutorAsyncBase]: + cls, engine_config: VllmConfig) -> Type[ExecutorAsyncBase]: distributed_executor_backend = ( engine_config.parallel_config.distributed_executor_backend) if isinstance(distributed_executor_backend, type): @@ -663,7 +663,7 @@ def _get_executor_cls( def 
from_engine_args( cls, engine_args: AsyncEngineArgs, - engine_config: Optional[EngineConfig] = None, + engine_config: Optional[VllmConfig] = None, start_engine_loop: bool = True, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index cbe01cb3ca103..1d6060cc278e8 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -13,8 +13,9 @@ from typing_extensions import TypeIs, TypeVar import vllm.envs as envs -from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig, - ObservabilityConfig, ParallelConfig, SchedulerConfig) +from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, + ObservabilityConfig, ParallelConfig, SchedulerConfig, + VllmConfig) from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler, SchedulerOutputs) from vllm.engine.arg_utils import EngineArgs @@ -219,7 +220,7 @@ def validate_outputs( def __init__( self, - vllm_config: EngineConfig, + vllm_config: VllmConfig, executor_class: Type[ExecutorBase], log_stats: bool, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, @@ -500,7 +501,7 @@ def _initialize_kv_caches(self) -> None: @classmethod def _get_executor_cls(cls, - engine_config: EngineConfig) -> Type[ExecutorBase]: + engine_config: VllmConfig) -> Type[ExecutorBase]: distributed_executor_backend = ( engine_config.parallel_config.distributed_executor_backend) # Initialize the cluster and specify the executor class. diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index 6e6630b3ff55f..7f1ca621d91c4 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -13,7 +13,7 @@ from zmq.asyncio import Socket from vllm import PoolingParams -from vllm.config import DecodingConfig, EngineConfig, ModelConfig +from vllm.config import DecodingConfig, ModelConfig, VllmConfig from vllm.core.scheduler import SchedulerOutputs from vllm.engine.arg_utils import AsyncEngineArgs # yapf conflicts with isort for this block @@ -78,7 +78,7 @@ class MQLLMEngineClient(EngineClient): every N seconds, confirming the engine is healthy """ - def __init__(self, ipc_path: str, engine_config: EngineConfig, + def __init__(self, ipc_path: str, engine_config: VllmConfig, engine_pid: int): self.context = zmq.asyncio.Context() self._errored_with: Optional[BaseException] = None diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py index e32993e0e452e..ab3ebb4e43d18 100644 --- a/vllm/executor/cpu_executor.py +++ b/vllm/executor/cpu_executor.py @@ -138,18 +138,11 @@ def _create_worker( assert self.distributed_init_method is not None kwargs = dict( - model_config=self.model_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config, - device_config=self.device_config, - cache_config=self.cache_config, - load_config=self.load_config, + vllm_config=self.vllm_config, local_rank=local_rank, rank=rank, distributed_init_method=self.distributed_init_method, - lora_config=self.lora_config, kv_cache_dtype=self.cache_config.cache_dtype, - prompt_adapter_config=self.prompt_adapter_config, is_driver_worker=rank == 0, ) wrapper.init_worker(**kwargs) diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index 2248eecd1849f..9cba189dd57f9 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from typing 
import List, Optional, Set, Tuple -from vllm.config import EngineConfig +from vllm.config import VllmConfig from vllm.lora.request import LoRARequest from vllm.model_executor.layers.sampler import SamplerOutput from vllm.prompt_adapter.request import PromptAdapterRequest @@ -20,7 +20,7 @@ class ExecutorBase(ABC): def __init__( self, - vllm_config: EngineConfig, + vllm_config: VllmConfig, ) -> None: self.vllm_config = vllm_config self.model_config = vllm_config.model_config diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index ed30d3186a453..c65d0836e5ff7 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -49,21 +49,12 @@ def _get_worker_kwargs( distributed_init_method = get_distributed_init_method( get_ip(), get_open_port()) return dict( - model_config=self.model_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config, - device_config=self.device_config, - cache_config=self.cache_config, - load_config=self.load_config, + vllm_config=self.vllm_config, local_rank=local_rank, rank=rank, distributed_init_method=distributed_init_method, - lora_config=self.lora_config, - speculative_config=self.speculative_config, - prompt_adapter_config=self.prompt_adapter_config, is_driver_worker=(not self.parallel_config) or (rank % self.parallel_config.tensor_parallel_size == 0), - observability_config=self.observability_config, ) def _get_worker_module_and_class( diff --git a/vllm/executor/neuron_executor.py b/vllm/executor/neuron_executor.py index f2fcfa58b26e1..02d37cd7fbf23 100644 --- a/vllm/executor/neuron_executor.py +++ b/vllm/executor/neuron_executor.py @@ -29,11 +29,7 @@ def _init_worker(self): distributed_init_method = get_distributed_init_method( get_ip(), get_open_port()) self.driver_worker = NeuronWorker( - model_config=self.model_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config, - device_config=self.device_config, - cache_config=self.cache_config, + vllm_config=self.vllm_config, local_rank=0, rank=0, distributed_init_method=distributed_init_method) diff --git a/vllm/executor/openvino_executor.py b/vllm/executor/openvino_executor.py index d0c0333854dae..d06b0ccb7906e 100644 --- a/vllm/executor/openvino_executor.py +++ b/vllm/executor/openvino_executor.py @@ -48,16 +48,10 @@ def _init_worker(self): get_ip(), get_open_port()) self.driver_worker = OpenVINOWorker( ov_core=self.ov_core, - model_config=self.model_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config, - device_config=self.device_config, - cache_config=self.cache_config, - load_config=self.load_config, + vllm_config=self.vllm_config, local_rank=0, rank=0, distributed_init_method=distributed_init_method, - lora_config=self.lora_config, kv_cache_dtype=self.cache_config.cache_dtype, is_driver_worker=True, ) diff --git a/vllm/executor/tpu_executor.py b/vllm/executor/tpu_executor.py index 972649dedf33e..e37e8973790db 100644 --- a/vllm/executor/tpu_executor.py +++ b/vllm/executor/tpu_executor.py @@ -44,12 +44,7 @@ def _get_worker_kwargs( distributed_init_method = get_distributed_init_method( get_ip(), get_open_port()) return dict( - model_config=self.model_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config, - device_config=self.device_config, - cache_config=self.cache_config, - load_config=self.load_config, + vllm_config=self.vllm_config, local_rank=local_rank, rank=rank, distributed_init_method=distributed_init_method, diff --git 
a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py index 3aa999fcb9ebb..17cc0ad1a4a3a 100644 --- a/vllm/spec_decode/draft_model_runner.py +++ b/vllm/spec_decode/draft_model_runner.py @@ -17,9 +17,6 @@ "Draft model speculative decoding currently only supports" "CUDA and ROCm flash attention backend.") from err -from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, ObservabilityConfig, ParallelConfig, - PromptAdapterConfig, SchedulerConfig) from vllm.logger import init_logger from vllm.multimodal import MultiModalInputs from vllm.sequence import ExecuteModelRequest, IntermediateTensors @@ -49,40 +46,13 @@ class TP1DraftModelRunner(ModelRunner): any broadcasting inside execute_model). """ - def __init__( - self, - model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - cache_config: CacheConfig, - load_config: LoadConfig, - lora_config: Optional[LoRAConfig], - kv_cache_dtype: Optional[str] = "auto", - is_driver_worker: bool = False, - prompt_adapter_config: Optional[PromptAdapterConfig] = None, - return_hidden_states: bool = False, - observability_config: Optional[ObservabilityConfig] = None, - ): - if return_hidden_states: + def __init__(self, *args, **kwargs): + if kwargs.get("return_hidden_states"): raise ValueError( "return_hidden_states is not supported for TP1DraftModelRunner." ) - super().__init__( - model_config=model_config, - parallel_config=parallel_config, - scheduler_config=scheduler_config, - device_config=device_config, - cache_config=cache_config, - load_config=load_config, - lora_config=lora_config, - kv_cache_dtype=kv_cache_dtype, - is_driver_worker=is_driver_worker, - prompt_adapter_config=prompt_adapter_config, - return_hidden_states=return_hidden_states, - observability_config=observability_config, - ) + super().__init__(*args, **kwargs) def _update_sampling_metadata(self, sampling_metadata, num_seqs, num_queries): diff --git a/vllm/spec_decode/ngram_worker.py b/vllm/spec_decode/ngram_worker.py index a777e5c3f22a7..debb3b2d5ec30 100644 --- a/vllm/spec_decode/ngram_worker.py +++ b/vllm/spec_decode/ngram_worker.py @@ -21,7 +21,7 @@ class NGramWorker(NonLLMProposerWorkerBase): def __init__(self, *args, **kwargs): # Get local_rank/vocab_size from kwargs attribute self.local_rank = kwargs["local_rank"] - self.vocab_size = kwargs["model_config"].get_vocab_size() + self.vocab_size = kwargs["vllm_config"].model_config.get_vocab_size() # Lazy initialization list. self._proposer: Top1Proposer diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 9f7ef2f8d851c..a402181b13db8 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -1,10 +1,11 @@ +import copy from collections import defaultdict from functools import cached_property from typing import Any, Dict, List, Optional, Set, Tuple, Type import torch -from vllm.config import ParallelConfig, SpeculativeConfig +from vllm.config import ParallelConfig, SpeculativeConfig, VllmConfig from vllm.distributed.communication_op import broadcast_tensor_dict from vllm.logger import init_logger from vllm.model_executor.layers.rejection_sampler import RejectionSampler @@ -45,8 +46,8 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker": """Helper method that is the entrypoint for Executors which use WorkerWrapper. It constructs a SpecDecodeWorker from the speculative config. 
""" - assert "speculative_config" in kwargs - speculative_config: SpeculativeConfig = kwargs.get("speculative_config") + vllm_config: VllmConfig = kwargs.get("vllm_config") + speculative_config: SpeculativeConfig = vllm_config.speculative_config assert speculative_config is not None draft_worker_kwargs = kwargs.copy() @@ -58,14 +59,16 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker": target_worker.model_runner.disable_logprobs =\ speculative_config.disable_logprobs + draft_worker_config = copy.deepcopy(vllm_config) + draft_worker_config.model_config = speculative_config.draft_model_config + draft_worker_config.parallel_config = speculative_config.draft_parallel_config # noqa + # TODO allow draft-model specific load config. + # Override draft-model specific worker args. draft_worker_kwargs.update( - model_config=speculative_config.draft_model_config, - parallel_config=speculative_config.draft_parallel_config, + vllm_config=draft_worker_config, ngram_prompt_lookup_max=speculative_config.ngram_prompt_lookup_max, ngram_prompt_lookup_min=speculative_config.ngram_prompt_lookup_min, - # TODO allow draft-model specific load config. - #load_config=load_config, ) spec_decode_worker = SpecDecodeWorker.create_worker( @@ -134,29 +137,27 @@ def create_worker( draft_worker_kwargs.pop("ngram_prompt_lookup_max")) ngram_prompt_lookup_min = ( draft_worker_kwargs.pop("ngram_prompt_lookup_min")) + draft_model_config = draft_worker_kwargs["vllm_config"].model_config + draft_parallel_config: ParallelConfig = draft_worker_kwargs[ + 'vllm_config'].parallel_config if ngram_prompt_lookup_max > 0: proposer_worker = NGramWorker(**draft_worker_kwargs) proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min, ngram_prompt_lookup_max) else: - draft_parallel_config: ParallelConfig = draft_worker_kwargs[ - 'parallel_config'] draft_tp = draft_parallel_config.tensor_parallel_size target_tp = scorer_worker.parallel_config.tensor_parallel_size - if draft_worker_kwargs[ - "model_config"].hf_config.model_type == "mlp_speculator": + if draft_model_config.hf_config.model_type == "mlp_speculator": proposer_worker = MLPSpeculatorWorker(**draft_worker_kwargs) - elif draft_worker_kwargs[ - "model_config"].hf_config.model_type == "medusa": + elif draft_model_config.hf_config.model_type == "medusa": proposer_worker = MedusaWorker(**draft_worker_kwargs) else: if draft_tp == 1: draft_worker_kwargs[ "model_runner_cls"] = TP1DraftModelRunner else: - if draft_worker_kwargs[ - "model_config"].hf_config.model_type == "eagle": + if draft_model_config.hf_config.model_type == "eagle": raise NotImplementedError( "EAGLE does not support TP > 1 yet") @@ -190,8 +191,8 @@ def create_worker( "[Speculative Decoding] Disabling MQA scorer as the " "MQA is only available with flash attn backend.") - if "model_config" in draft_worker_kwargs and \ - draft_worker_kwargs["model_config"].max_model_len < \ + if draft_model_config and \ + draft_model_config.max_model_len < \ scorer_worker.model_config.max_model_len: disable_mqa_scorer = True logger.info( diff --git a/vllm/spec_decode/target_model_runner.py b/vllm/spec_decode/target_model_runner.py index 2bb7af7d7c600..e61cde5b17f20 100644 --- a/vllm/spec_decode/target_model_runner.py +++ b/vllm/spec_decode/target_model_runner.py @@ -1,8 +1,6 @@ from typing import List, Optional -from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, ObservabilityConfig, ParallelConfig, - PromptAdapterConfig, SchedulerConfig) +from vllm.config import VllmConfig from 
vllm.sequence import SequenceGroupMetadata from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata, ModelRunner) @@ -20,35 +18,21 @@ class TargetModelRunner(ModelRunner): requested or not. """ - def __init__(self, - model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - cache_config: CacheConfig, - load_config: LoadConfig, - lora_config: Optional[LoRAConfig], - kv_cache_dtype: Optional[str] = "auto", - is_driver_worker: bool = False, - prompt_adapter_config: Optional[PromptAdapterConfig] = None, - return_hidden_states: bool = False, - observability_config: Optional[ObservabilityConfig] = None): + def __init__( + self, + vllm_config: VllmConfig, + kv_cache_dtype: Optional[str] = "auto", + is_driver_worker: bool = False, + return_hidden_states: bool = False, + ): # An internal boolean member variable to indicate if token log # probabilities are needed or not. self.disable_logprobs = True super().__init__( - model_config=model_config, - parallel_config=parallel_config, - scheduler_config=scheduler_config, - device_config=device_config, - cache_config=cache_config, - load_config=load_config, - lora_config=lora_config, + vllm_config=vllm_config, kv_cache_dtype=kv_cache_dtype, is_driver_worker=is_driver_worker, - prompt_adapter_config=prompt_adapter_config, return_hidden_states=return_hidden_states, - observability_config=observability_config, ) def prepare_model_input( diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index febabd2f31036..64cc18149d6c5 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -2,8 +2,9 @@ from typing import (Any, Dict, Iterable, List, Mapping, Optional, Tuple, Type, Union) -from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig, - ObservabilityConfig, ParallelConfig, SchedulerConfig) +from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, + ObservabilityConfig, ParallelConfig, SchedulerConfig, + VllmConfig) from vllm.engine.arg_utils import EngineArgs from vllm.engine.metrics_types import StatLoggerBase from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, @@ -32,7 +33,7 @@ class LLMEngine: def __init__( self, - vllm_config: EngineConfig, + vllm_config: VllmConfig, executor_class: Type[GPUExecutor], log_stats: bool, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, @@ -477,7 +478,7 @@ def get_lora_config(self) -> LoRAConfig: return self.lora_config @classmethod - def _get_executor_cls(cls, engine_config: EngineConfig): + def _get_executor_cls(cls, engine_config: VllmConfig): return GPUExecutor def is_tracing_enabled(self) -> bool: diff --git a/vllm/v1/executor/gpu_executor.py b/vllm/v1/executor/gpu_executor.py index c780c7031c3d6..b12c500f1f9ee 100644 --- a/vllm/v1/executor/gpu_executor.py +++ b/vllm/v1/executor/gpu_executor.py @@ -56,19 +56,10 @@ def _create_worker( distributed_init_method = get_distributed_init_method( get_ip(), get_open_port()) return Worker( - model_config=self.model_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config, - device_config=self.device_config, - cache_config=self.cache_config, - load_config=self.load_config, + vllm_config=self.vllm_config, local_rank=local_rank, rank=rank, distributed_init_method=distributed_init_method, - lora_config=self.lora_config, - speculative_config=self.speculative_config, - prompt_adapter_config=self.prompt_adapter_config, - observability_config=self.observability_config, ) def 
determine_num_available_blocks(self) -> Tuple[int, int]: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index e84645ac7a4ae..77c1e10ab6bdf 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -7,9 +7,7 @@ import torch.distributed import torch.nn as nn -from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, ObservabilityConfig, ParallelConfig, - PromptAdapterConfig, SchedulerConfig) +from vllm.config import VllmConfig from vllm.forward_context import set_forward_context from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model @@ -33,26 +31,25 @@ class GPUModelRunner: def __init__( self, - model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - cache_config: CacheConfig, - load_config: LoadConfig, - lora_config: Optional[LoRAConfig] = None, - prompt_adapter_config: Optional[PromptAdapterConfig] = None, - observability_config: Optional[ObservabilityConfig] = None, + vllm_config: VllmConfig, ): - self.model_config = model_config - self.parallel_config = parallel_config - self.scheduler_config = scheduler_config - self.device_config = device_config - self.cache_config = cache_config - self.lora_config = lora_config - self.load_config = load_config - self.prompt_adapter_config = prompt_adapter_config - self.observability_config = observability_config - + # TODO: use ModelRunnerBase.__init__(self, vllm_config=vllm_config) + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.cache_config = vllm_config.cache_config + self.lora_config = vllm_config.lora_config + self.load_config = vllm_config.load_config + self.parallel_config = vllm_config.parallel_config + self.scheduler_config = vllm_config.scheduler_config + self.device_config = vllm_config.device_config + self.speculative_config = vllm_config.speculative_config + self.prompt_adapter_config = vllm_config.prompt_adapter_config + self.observability_config = vllm_config.observability_config + + model_config = self.model_config + cache_config = self.cache_config + scheduler_config = self.scheduler_config + parallel_config = self.parallel_config self.device = self.device_config.device self.pin_memory = is_pin_memory_available() self.dtype = self.model_config.dtype diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 8c5ca2ec35666..c8192b7f86eb0 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -6,10 +6,7 @@ import torch import torch.distributed -from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, ObservabilityConfig, ParallelConfig, - PromptAdapterConfig, SchedulerConfig, - SpeculativeConfig) +from vllm.config import CacheConfig, ModelConfig, ParallelConfig, VllmConfig from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment, set_custom_all_reduce) @@ -30,48 +27,35 @@ class Worker: def __init__( self, - model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - cache_config: CacheConfig, - load_config: LoadConfig, + vllm_config: VllmConfig, local_rank: int, rank: int, distributed_init_method: str, - speculative_config: Optional[SpeculativeConfig] = None, - lora_config: Optional[LoRAConfig] = None, - prompt_adapter_config: Optional[PromptAdapterConfig] = None, - observability_config: 
Optional[ObservabilityConfig] = None, ): - self.model_config = model_config - self.parallel_config = parallel_config - self.scheduler_config = scheduler_config - self.device_config = device_config - self.cache_config = cache_config - self.load_config = load_config + + # TODO: use WorkerBase.__init__(self, vllm_config=vllm_config) + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.cache_config = vllm_config.cache_config + self.lora_config = vllm_config.lora_config + self.load_config = vllm_config.load_config + self.parallel_config = vllm_config.parallel_config + self.scheduler_config = vllm_config.scheduler_config + self.device_config = vllm_config.device_config + self.speculative_config = vllm_config.speculative_config + self.prompt_adapter_config = vllm_config.prompt_adapter_config + self.observability_config = vllm_config.observability_config + self.local_rank = local_rank self.rank = rank self.distributed_init_method = distributed_init_method - self.lora_config = lora_config - self.speculative_config = speculative_config - self.prompt_adapter_config = prompt_adapter_config - self.observability_config = observability_config if self.model_config.trust_remote_code: # note: lazy import to avoid importing torch before initializing from vllm.utils import init_cached_hf_modules init_cached_hf_modules() - self.model_runner = GPUModelRunner( - model_config, - parallel_config, - scheduler_config, - device_config, - cache_config, - load_config, - lora_config=lora_config, - ) + self.model_runner = GPUModelRunner(vllm_config) def initialize(self): if self.device_config.device.type == "cuda": diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index 0c6fcdf03ba9e..a98faa2f2d0cb 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -8,9 +8,7 @@ from torch import nn from vllm.attention import AttentionMetadata, get_attn_backend -from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, ParallelConfig, PromptAdapterConfig, - SchedulerConfig) +from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding @@ -412,29 +410,18 @@ class CPUModelRunner(ModelRunnerBase[ModelInputForCPU]): def __init__( self, - model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - cache_config: CacheConfig, - load_config: LoadConfig, - lora_config: Optional[LoRAConfig], + vllm_config: VllmConfig, kv_cache_dtype: Optional[str] = "auto", - prompt_adapter_config: Optional[PromptAdapterConfig] = None, is_driver_worker: bool = False, *args, **kwargs, ): - self.model_config = model_config - self.parallel_config = parallel_config - self.scheduler_config = scheduler_config + ModelRunnerBase.__init__(self, vllm_config) # Currently, CPU worker doesn't support chunked prefill. 
assert self.scheduler_config.chunked_prefill_enabled is False - self.device_config = device_config - self.cache_config = cache_config - self.lora_config = lora_config - self.prompt_adapter_config = prompt_adapter_config - self.load_config = load_config + model_config = self.model_config + cache_config = self.cache_config + self.is_driver_worker = is_driver_worker self.device = self.device_config.device diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index ab93471b5af74..3778707ae07e8 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -6,9 +6,8 @@ import vllm.envs as envs from vllm.attention import get_attn_backend -from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, ParallelConfig, PromptAdapterConfig, - SchedulerConfig) +from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, + ParallelConfig, VllmConfig) from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment) from vllm.logger import init_logger @@ -18,7 +17,8 @@ from vllm.worker.cpu_enc_dec_model_runner import CPUEncoderDecoderModelRunner from vllm.worker.cpu_model_runner import CPUModelRunner from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, - LoraNotSupportedWorkerBase, WorkerInput) + LoraNotSupportedWorkerBase, WorkerBase, + WorkerInput) logger = init_logger(__name__) @@ -121,31 +121,19 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): def __init__( self, - model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - cache_config: CacheConfig, - load_config: LoadConfig, + vllm_config: VllmConfig, local_rank: int, rank: int, distributed_init_method: str, - lora_config: Optional[LoRAConfig] = None, kv_cache_dtype: Optional[str] = "auto", - prompt_adapter_config: Optional[PromptAdapterConfig] = None, is_driver_worker: bool = False, ) -> None: - self.model_config = model_config - self.parallel_config = parallel_config - self.scheduler_config = scheduler_config - self.device_config = device_config - self.cache_config = cache_config - self.load_config = load_config + WorkerBase.__init__(self, vllm_config=vllm_config) + self.local_rank = local_rank self.rank = rank self.distributed_init_method = distributed_init_method - self.lora_config = lora_config - self.prompt_adapter_config = prompt_adapter_config + self.is_driver_worker = is_driver_worker if self.is_driver_worker: assert self.rank == 0, "The driver worker must have rank 0." @@ -166,15 +154,8 @@ def __init__( if self._is_encoder_decoder_model(): ModelRunnerClass = CPUEncoderDecoderModelRunner self.model_runner: CPUModelRunner = ModelRunnerClass( - model_config, - parallel_config, - scheduler_config, - device_config, - cache_config, - load_config=self.load_config, - lora_config=self.lora_config, + vllm_config=vllm_config, kv_cache_dtype=kv_cache_dtype, - prompt_adapter_config=self.prompt_adapter_config, is_driver_worker=is_driver_worker) # Uninitialized cache engine. Will be initialized by # initialize_cache. 
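A minimal, self-contained sketch of the pattern every hunk in this patch applies: the long list of per-config constructor parameters collapses into a single vllm_config argument, and a shared base __init__ fans the sub-configs back out onto the instance. The simplified config classes below are stand-ins for illustration only, not vLLM's real constructors; only the names VllmConfig, Worker, and the vllm_config/local_rank/rank keywords mirror the diff.

from dataclasses import dataclass
from typing import Optional


# Simplified stand-ins for vLLM's real config classes (illustration only).
@dataclass
class ModelConfig:
    model: str


@dataclass
class CacheConfig:
    block_size: int = 16


@dataclass
class LoRAConfig:
    max_lora_rank: int = 8


@dataclass
class VllmConfig:
    # Core configs are required; optional ones default to None, mirroring
    # the new dataclass defaults added in vllm/config.py above.
    model_config: ModelConfig
    cache_config: CacheConfig
    lora_config: Optional[LoRAConfig] = None


class WorkerBase:
    def __init__(self, vllm_config: VllmConfig) -> None:
        # Single place that unpacks the aggregate config onto the instance,
        # analogous to what WorkerBase.__init__ / ModelRunnerBase.__init__
        # do in this patch.
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.lora_config = vllm_config.lora_config


class Worker(WorkerBase):
    def __init__(self, vllm_config: VllmConfig, local_rank: int,
                 rank: int) -> None:
        WorkerBase.__init__(self, vllm_config=vllm_config)
        self.local_rank = local_rank
        self.rank = rank


# Call sites build the aggregate once and pass it through unchanged.
config = VllmConfig(model_config=ModelConfig("meta-llama/Llama-2-7b-hf"),
                    cache_config=CacheConfig(),
                    lora_config=LoRAConfig())
worker = Worker(vllm_config=config, local_rank=0, rank=0)

This is why the executor and test hunks shrink to `vllm_config=...` plus the worker-specific arguments (rank, distributed_init_method, kv_cache_dtype, is_driver_worker): everything model-, cache-, or adapter-related now travels inside the one dataclass.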
diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index a7f5b2d4fdd1f..ff288d5ca1512 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/embedding_model_runner.py @@ -3,9 +3,7 @@ import torch -from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, ObservabilityConfig, ParallelConfig, - PromptAdapterConfig, SchedulerConfig) +from vllm.config import VllmConfig from vllm.distributed import get_pp_group from vllm.forward_context import set_forward_context from vllm.logger import init_logger @@ -36,29 +34,13 @@ class EmbeddingModelRunner( def __init__( self, - model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - cache_config: CacheConfig, - load_config: LoadConfig, - lora_config: Optional[LoRAConfig], + vllm_config: VllmConfig, kv_cache_dtype: Optional[str] = "auto", is_driver_worker: bool = False, - prompt_adapter_config: Optional[PromptAdapterConfig] = None, - observability_config: Optional[ObservabilityConfig] = None, ): - super().__init__(model_config, - parallel_config, - scheduler_config, - device_config, - cache_config, - load_config, - lora_config=lora_config, + super().__init__(vllm_config=vllm_config, kv_cache_dtype=kv_cache_dtype, - is_driver_worker=is_driver_worker, - prompt_adapter_config=prompt_adapter_config, - observability_config=observability_config) + is_driver_worker=is_driver_worker) @torch.inference_mode() def execute_model( diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index efbd2ae523d60..6c52fead78a40 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -11,9 +11,7 @@ from vllm.attention.selector import (_Backend, get_env_variable_attn_backend, get_global_forced_attn_backend, global_force_attn_backend) -from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, ObservabilityConfig, ParallelConfig, - PromptAdapterConfig, SchedulerConfig) +from vllm.config import ModelConfig, VllmConfig from vllm.forward_context import set_forward_context from vllm.inputs import INPUT_REGISTRY, InputRegistry from vllm.logger import init_logger @@ -86,17 +84,9 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]): def __init__( self, - model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - cache_config: CacheConfig, - load_config: LoadConfig, - lora_config: Optional[LoRAConfig], + vllm_config: VllmConfig, kv_cache_dtype: Optional[str] = "auto", is_driver_worker: bool = False, - prompt_adapter_config: Optional[PromptAdapterConfig] = None, - observability_config: Optional[ObservabilityConfig] = None, input_registry: InputRegistry = INPUT_REGISTRY, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, ): @@ -108,15 +98,10 @@ def __init__( models) but these arguments are present here for compatibility with the base-class constructor. 
''' - self._maybe_force_supported_attention_backend(model_config) + self._maybe_force_supported_attention_backend(vllm_config.model_config) + super().__init__( - model_config, - parallel_config, - scheduler_config, - device_config, - cache_config, - load_config, - lora_config=None, + vllm_config=vllm_config, kv_cache_dtype=kv_cache_dtype, is_driver_worker=is_driver_worker, ) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 79e5a14376fa1..c4b1a950ee8f8 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -20,9 +20,7 @@ from vllm.attention.backends.utils import CommonAttentionState from vllm.compilation.compile_context import set_compile_context from vllm.compilation.levels import CompilationLevel -from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, ObservabilityConfig, ParallelConfig, - PromptAdapterConfig, SchedulerConfig) +from vllm.config import VllmConfig from vllm.core.scheduler import SchedulerOutputs from vllm.distributed import get_pp_group from vllm.distributed.parallel_state import graph_capture @@ -955,32 +953,20 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): def __init__( self, - model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - cache_config: CacheConfig, - load_config: LoadConfig, - lora_config: Optional[LoRAConfig], + vllm_config: VllmConfig, kv_cache_dtype: Optional[str] = "auto", is_driver_worker: bool = False, - prompt_adapter_config: Optional[PromptAdapterConfig] = None, return_hidden_states: bool = False, - observability_config: Optional[ObservabilityConfig] = None, input_registry: InputRegistry = INPUT_REGISTRY, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, ): - self.model_config = model_config - self.parallel_config = parallel_config - self.scheduler_config = scheduler_config - self.device_config = device_config - self.cache_config = cache_config - self.lora_config = lora_config - self.load_config = load_config + + ModelRunnerBase.__init__(self, vllm_config) + model_config = self.model_config + cache_config = self.cache_config + self.is_driver_worker = is_driver_worker - self.prompt_adapter_config = prompt_adapter_config self.return_hidden_states = return_hidden_states - self.observability_config = observability_config self.device = self.device_config.device self.pin_memory = is_pin_memory_available() diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py index 89d7addb5a8d9..9e529f86b46bb 100644 --- a/vllm/worker/model_runner_base.py +++ b/vllm/worker/model_runner_base.py @@ -9,6 +9,7 @@ import torch from torch import is_tensor +from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.layers.sampler import SamplerOutput from vllm.platforms import current_platform @@ -220,6 +221,22 @@ class ModelRunnerBase(ABC, Generic[T]): ModelRunnerInputBase subclass. 
""" + def __init__( + self, + vllm_config: VllmConfig, + ) -> None: + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.cache_config = vllm_config.cache_config + self.lora_config = vllm_config.lora_config + self.load_config = vllm_config.load_config + self.parallel_config = vllm_config.parallel_config + self.scheduler_config = vllm_config.scheduler_config + self.device_config = vllm_config.device_config + self.speculative_config = vllm_config.speculative_config + self.prompt_adapter_config = vllm_config.prompt_adapter_config + self.observability_config = vllm_config.observability_config + # Map of request_id -> generator used for seeded random sampling generators: Dict[str, torch.Generator] = {} diff --git a/vllm/worker/multi_step_model_runner.py b/vllm/worker/multi_step_model_runner.py index be2f0d79154d6..3ee0fb4dc943e 100644 --- a/vllm/worker/multi_step_model_runner.py +++ b/vllm/worker/multi_step_model_runner.py @@ -304,6 +304,7 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]): # mypy: enable-error-code=type-var def __init__(self, base_model_runner: GPUModelRunnerBase, *args, **kwargs): + super().__init__(*args, **kwargs) # Check attention backend support. diff --git a/vllm/worker/multi_step_worker.py b/vllm/worker/multi_step_worker.py index bf66f32d7d244..1f982fe103366 100644 --- a/vllm/worker/multi_step_worker.py +++ b/vllm/worker/multi_step_worker.py @@ -27,17 +27,9 @@ def __init__(self, *args, **kwargs): # for multi-step model, wrap the model runner with MultiStepModelRunner self.model_runner = MultiStepModelRunner( base_model_runner, - base_model_runner.model_config, - base_model_runner.parallel_config, - base_model_runner.scheduler_config, - base_model_runner.device_config, - base_model_runner.cache_config, - load_config=base_model_runner.load_config, - lora_config=self.lora_config, + vllm_config=base_model_runner.vllm_config, kv_cache_dtype=self.cache_config.cache_dtype, is_driver_worker=base_model_runner.is_driver_worker, - prompt_adapter_config=base_model_runner.prompt_adapter_config, - observability_config=base_model_runner.observability_config, ) pipeline_parallel_size = self.parallel_config.pipeline_parallel_size diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py index b8c760c4b5396..2da22cbfc7cb5 100644 --- a/vllm/worker/neuron_model_runner.py +++ b/vllm/worker/neuron_model_runner.py @@ -7,8 +7,7 @@ from torch import nn from transformers_neuronx.config import GenerationConfig -from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig, - SchedulerConfig) +from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.sampler import SamplerOutput @@ -57,20 +56,13 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]): def __init__( self, - model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, + vllm_config: VllmConfig, ): - self.model_config = model_config - self.parallel_config = parallel_config - self.scheduler_config = scheduler_config - + ModelRunnerBase.__init__(self, vllm_config) + model_config = self.model_config if model_config is not None and model_config.get_sliding_window(): logger.warning("Sliding window is not supported on Neuron. 
" "The model will run without sliding window.") - self.device_config = (device_config - if device_config is not None else DeviceConfig()) self.device = self.device_config.device self.pin_memory = is_pin_memory_available() diff --git a/vllm/worker/neuron_worker.py b/vllm/worker/neuron_worker.py index fff14d6402b44..3f6269684ac93 100644 --- a/vllm/worker/neuron_worker.py +++ b/vllm/worker/neuron_worker.py @@ -4,15 +4,15 @@ import torch import torch.distributed -from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, - ParallelConfig, SchedulerConfig) +from vllm.config import VllmConfig from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment) from vllm.model_executor import set_random_seed from vllm.sequence import ExecuteModelRequest from vllm.worker.neuron_model_runner import NeuronModelRunner from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, - LoraNotSupportedWorkerBase, WorkerInput) + LoraNotSupportedWorkerBase, WorkerBase, + WorkerInput) class NeuronWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): @@ -21,20 +21,12 @@ class NeuronWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): def __init__( self, - model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - cache_config: CacheConfig, + vllm_config: VllmConfig, local_rank: int, rank: int, distributed_init_method: str, ) -> None: - self.model_config = model_config - self.parallel_config = parallel_config - self.scheduler_config = scheduler_config - self.device_config = device_config - self.cache_config = cache_config + WorkerBase.__init__(self, vllm_config=vllm_config) self.local_rank = local_rank self.rank = rank self.distributed_init_method = distributed_init_method @@ -44,7 +36,7 @@ def __init__( init_cached_hf_modules() self.model_runner: NeuronModelRunner = NeuronModelRunner( - model_config, parallel_config, scheduler_config, device_config) + vllm_config=vllm_config) self.is_driver_worker = True def init_device(self) -> None: diff --git a/vllm/worker/openvino_model_runner.py b/vllm/worker/openvino_model_runner.py index 3da738636a59d..c9c87ea748081 100644 --- a/vllm/worker/openvino_model_runner.py +++ b/vllm/worker/openvino_model_runner.py @@ -7,9 +7,7 @@ from vllm.attention import get_attn_backend from vllm.attention.backends.openvino import OpenVINOAttentionMetadata -from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, MultiModalConfig, ParallelConfig, - SchedulerConfig) +from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.sampler import SamplerOutput @@ -17,6 +15,7 @@ from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, MultiModalInputs, MultiModalPlaceholderMap) from vllm.sequence import SequenceGroupMetadata +from vllm.worker.model_runner_base import ModelRunnerBase logger = init_logger(__name__) @@ -39,33 +38,21 @@ def empty(cls, device): multi_modal_kwargs={}) -class OpenVINOModelRunner: +class OpenVINOModelRunner(ModelRunnerBase): def __init__( self, ov_core: ov.Core, - model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - cache_config: CacheConfig, - load_config: LoadConfig, - lora_config: Optional[LoRAConfig], - multimodal_config: Optional[MultiModalConfig], + vllm_config: VllmConfig, kv_cache_dtype: Optional[str] = "auto", 
is_driver_worker: bool = False, *args, **kwargs, ): self.ov_core = ov_core - self.model_config = model_config - self.parallel_config = parallel_config - self.scheduler_config = scheduler_config - self.device_config = device_config - self.cache_config = cache_config - self.lora_config = lora_config - self.multimodal_config = multimodal_config - self.load_config = load_config + ModelRunnerBase.__init__(self, vllm_config=vllm_config) + cache_config = self.cache_config + model_config = self.model_config self.is_driver_worker = is_driver_worker self.device = self.device_config.device @@ -369,3 +356,9 @@ def execute_model( sampling_metadata=sampling_metadata, ) return output + + def prepare_model_input(self, *args, **kwargs): + raise NotImplementedError + + def make_model_input_from_broadcasted_tensor_dict(self, *args, **kwargs): + raise NotImplementedError diff --git a/vllm/worker/openvino_worker.py b/vllm/worker/openvino_worker.py index a420d390c1ae4..205f8a337ce6c 100644 --- a/vllm/worker/openvino_worker.py +++ b/vllm/worker/openvino_worker.py @@ -7,9 +7,8 @@ import vllm.envs as envs from vllm.attention import get_attn_backend -from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, MultiModalConfig, ParallelConfig, - SchedulerConfig) +from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, + ParallelConfig, VllmConfig) from vllm.distributed import (broadcast_tensor_dict, ensure_model_parallel_initialized, init_distributed_environment) @@ -22,7 +21,7 @@ from vllm.sampling_params import SamplingParams from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata from vllm.worker.openvino_model_runner import OpenVINOModelRunner -from vllm.worker.worker_base import LoraNotSupportedWorkerBase +from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase logger = init_logger(__name__) @@ -212,33 +211,19 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase): def __init__( self, ov_core: ov.Core, - model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - cache_config: CacheConfig, - load_config: LoadConfig, + vllm_config: VllmConfig, local_rank: int, rank: int, distributed_init_method: str, - lora_config: Optional[LoRAConfig] = None, - multimodal_config: Optional[MultiModalConfig] = None, kv_cache_dtype: Optional[ov.Type] = ov.Type.undefined, is_driver_worker: bool = False, ) -> None: self.ov_core = ov_core - self.model_config = model_config - self.parallel_config = parallel_config + WorkerBase.__init__(self, vllm_config) self.parallel_config.rank = rank - self.scheduler_config = scheduler_config - self.device_config = device_config - self.cache_config = cache_config - self.load_config = load_config self.local_rank = local_rank self.rank = rank self.distributed_init_method = distributed_init_method - self.lora_config = lora_config - self.multimodal_config = multimodal_config self.is_driver_worker = is_driver_worker if self.is_driver_worker: assert self.rank == 0, "The driver worker must have rank 0." 
@@ -250,14 +235,7 @@ def __init__( init_cached_hf_modules() self.model_runner = OpenVINOModelRunner( self.ov_core, - model_config, - parallel_config, - scheduler_config, - device_config, - cache_config, - load_config=self.load_config, - lora_config=self.lora_config, - multimodal_config=self.multimodal_config, + vllm_config=self.vllm_config, kv_cache_dtype=kv_cache_dtype, is_driver_worker=is_driver_worker, ) diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index 3792cbc0f730f..7d9d669a45ce3 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -12,8 +12,7 @@ from vllm.attention import AttentionMetadata, get_attn_backend from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher -from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig, - ParallelConfig, SchedulerConfig) +from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader import get_model @@ -90,20 +89,10 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]): def __init__( self, - model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - cache_config: CacheConfig, - load_config: LoadConfig, + vllm_config: VllmConfig, is_driver_worker: bool = False, ): - self.model_config = model_config - self.parallel_config = parallel_config - self.scheduler_config = scheduler_config - self.device_config = device_config - self.cache_config = cache_config - self.load_config = load_config + ModelRunnerBase.__init__(self, vllm_config=vllm_config) self.is_driver_worker = is_driver_worker self.block_size = self.cache_config.block_size diff --git a/vllm/worker/tpu_worker.py b/vllm/worker/tpu_worker.py index de6f7ab0072fd..096cb23416909 100644 --- a/vllm/worker/tpu_worker.py +++ b/vllm/worker/tpu_worker.py @@ -6,8 +6,7 @@ import torch_xla.runtime as xr import vllm.envs as envs -from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig, - ParallelConfig, SchedulerConfig) +from vllm.config import VllmConfig from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment) from vllm.logger import init_logger @@ -16,7 +15,8 @@ from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size from vllm.worker.tpu_model_runner import TPUModelRunner from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, - LoraNotSupportedWorkerBase, WorkerInput) + LoraNotSupportedWorkerBase, WorkerBase, + WorkerInput) logger = init_logger(__name__) @@ -25,24 +25,14 @@ class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): def __init__( self, - model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - cache_config: CacheConfig, - load_config: LoadConfig, + vllm_config: VllmConfig, local_rank: int, rank: int, distributed_init_method: str, is_driver_worker: bool, ) -> None: - self.model_config = model_config - self.parallel_config = parallel_config + WorkerBase.__init__(self, vllm_config=vllm_config) self.parallel_config.rank = rank - self.scheduler_config = scheduler_config - self.device_config = device_config - self.cache_config = cache_config - self.load_config = load_config self.local_rank = local_rank self.rank = rank self.distributed_init_method = distributed_init_method @@ -56,13 +46,7 @@ def __init__( self.cache_config.cache_dtype] 
         self.model_runner: TPUModelRunner = TPUModelRunner(
-            model_config,
-            parallel_config,
-            scheduler_config,
-            device_config,
-            cache_config,
-            load_config,
-            is_driver_worker=is_driver_worker)
+            vllm_config=vllm_config, is_driver_worker=is_driver_worker)

     def init_device(self) -> None:
         os.environ["PJRT_DEVICE"] = "TPU"
diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
index d42f9625cc9d6..59b68816366cf 100644
--- a/vllm/worker/worker.py
+++ b/vllm/worker/worker.py
@@ -8,10 +8,7 @@
 import torch.distributed

 import vllm.envs as envs
-from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
-                         ModelConfig, ObservabilityConfig, ParallelConfig,
-                         PromptAdapterConfig, SchedulerConfig,
-                         SpeculativeConfig)
+from vllm.config import ParallelConfig, VllmConfig
 from vllm.distributed import (ensure_model_parallel_initialized,
                               init_distributed_environment,
                               set_custom_all_reduce)
@@ -28,7 +25,8 @@
 from vllm.worker.embedding_model_runner import EmbeddingModelRunner
 from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
 from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner
-from vllm.worker.worker_base import LocalOrDistributedWorkerBase, WorkerInput
+from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase,
+                                     WorkerInput)

 logger = init_logger(__name__)
@@ -43,46 +41,31 @@ class Worker(LocalOrDistributedWorkerBase):
     def __init__(
         self,
-        model_config: ModelConfig,
-        parallel_config: ParallelConfig,
-        scheduler_config: SchedulerConfig,
-        device_config: DeviceConfig,
-        cache_config: CacheConfig,
-        load_config: LoadConfig,
+        vllm_config: VllmConfig,
         local_rank: int,
         rank: int,
         distributed_init_method: str,
-        lora_config: Optional[LoRAConfig] = None,
-        speculative_config: Optional[SpeculativeConfig] = None,
-        prompt_adapter_config: Optional[PromptAdapterConfig] = None,
         is_driver_worker: bool = False,
         model_runner_cls: Optional[Type[GPUModelRunnerBase]] = None,
-        observability_config: Optional[ObservabilityConfig] = None,
     ) -> None:
-        self.model_config = model_config
-        self.parallel_config = parallel_config
+        WorkerBase.__init__(self, vllm_config)
         self.parallel_config.rank = rank
-        self.scheduler_config = scheduler_config
-        self.device_config = device_config
-        self.cache_config = cache_config
         self.local_rank = local_rank
         self.rank = rank
         self.distributed_init_method = distributed_init_method
-        self.lora_config = lora_config
-        self.load_config = load_config
-        self.prompt_adapter_config = prompt_adapter_config
        self.is_driver_worker = is_driver_worker
-        if parallel_config and is_driver_worker:
-            assert rank % parallel_config.tensor_parallel_size == 0, \
+        if is_driver_worker:
+            assert rank % self.parallel_config.tensor_parallel_size == 0, \
                 "Driver worker should be rank 0 of tensor parallel group."
         if self.model_config.trust_remote_code:
             # note: lazy import to avoid importing torch before initializing
             from vllm.utils import init_cached_hf_modules
             init_cached_hf_modules()
-        self.observability_config = observability_config

         # Return hidden states from target model if the draft model is an
         # mlp_speculator
+        speculative_config = self.speculative_config
+        model_config = self.model_config
         speculative_args = {} if speculative_config is None \
             or (speculative_config.draft_model_config.model ==
                 model_config.model) \
@@ -98,17 +81,9 @@ def __init__(
         elif self._is_encoder_decoder_model():
             ModelRunnerClass = EncoderDecoderModelRunner
         self.model_runner: GPUModelRunnerBase = ModelRunnerClass(
-            model_config,
-            parallel_config,
-            scheduler_config,
-            device_config,
-            cache_config,
-            load_config=load_config,
-            lora_config=self.lora_config,
+            vllm_config=self.vllm_config,
             kv_cache_dtype=self.cache_config.cache_dtype,
             is_driver_worker=is_driver_worker,
-            prompt_adapter_config=prompt_adapter_config,
-            observability_config=observability_config,
             **speculative_args,
         )
         # Uninitialized cache engine. Will be initialized by
diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py
index 6ba4f272315ce..cf8a4946a71c4 100644
--- a/vllm/worker/worker_base.py
+++ b/vllm/worker/worker_base.py
@@ -7,7 +7,7 @@
 import torch

-from vllm.config import ObservabilityConfig
+from vllm.config import ObservabilityConfig, VllmConfig
 from vllm.distributed import broadcast_tensor_dict, get_pp_group, get_tp_group
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
@@ -29,6 +29,22 @@ class WorkerBase(ABC):
     communicate request metadata to other workers.
     """

+    def __init__(
+        self,
+        vllm_config: VllmConfig,
+    ) -> None:
+        self.vllm_config = vllm_config
+        self.model_config = vllm_config.model_config
+        self.cache_config = vllm_config.cache_config
+        self.lora_config = vllm_config.lora_config
+        self.load_config = vllm_config.load_config
+        self.parallel_config = vllm_config.parallel_config
+        self.scheduler_config = vllm_config.scheduler_config
+        self.device_config = vllm_config.device_config
+        self.speculative_config = vllm_config.speculative_config
+        self.prompt_adapter_config = vllm_config.prompt_adapter_config
+        self.observability_config = vllm_config.observability_config
+
     @abstractmethod
     def init_device(self) -> None:
         """Initialize device state, such as loading the model or other on-device
diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py
index 739fe1b3d2c4f..f37d70bee76ed 100644
--- a/vllm/worker/xpu_model_runner.py
+++ b/vllm/worker/xpu_model_runner.py
@@ -10,9 +10,7 @@
 import torch.nn as nn

 from vllm.attention import get_attn_backend
-from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
-                         ModelConfig, ObservabilityConfig, ParallelConfig,
-                         PromptAdapterConfig, SchedulerConfig)
+from vllm.config import VllmConfig
 from vllm.distributed import get_pp_group
 from vllm.inputs import INPUT_REGISTRY, InputRegistry
 from vllm.logger import init_logger
@@ -363,33 +361,18 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
     def __init__(
         self,
-        model_config: ModelConfig,
-        parallel_config: ParallelConfig,
-        scheduler_config: SchedulerConfig,
-        device_config: DeviceConfig,
-        cache_config: CacheConfig,
-        load_config: LoadConfig,
-        lora_config: Optional[LoRAConfig],
+        vllm_config: VllmConfig,
         kv_cache_dtype: Optional[str] = "auto",
         is_driver_worker: bool = False,
-        prompt_adapter_config: Optional[PromptAdapterConfig] = None,
         return_hidden_states: bool = False,
-        observability_config: Optional[ObservabilityConfig] = None,
         input_registry: InputRegistry = INPUT_REGISTRY,
         mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
     ):
-        self.model_config = model_config
-        self.parallel_config = parallel_config
-        self.scheduler_config = scheduler_config
-        self.device_config = device_config
-        self.cache_config = cache_config
-        self.lora_config = lora_config
-        self.load_config = load_config
+
+        ModelRunnerBase.__init__(self, vllm_config=vllm_config)
+        model_config = self.model_config
+        cache_config = self.cache_config
         self.is_driver_worker = is_driver_worker
-        self.prompt_adapter_config = prompt_adapter_config
-        self.observability_config = observability_config
-        if self.observability_config is not None:
-            print(f"observability_config is {self.observability_config}")
         self.return_hidden_states = return_hidden_states

         self.device = self.device_config.device
diff --git a/vllm/worker/xpu_worker.py b/vllm/worker/xpu_worker.py
index c1d836bb0d318..1295666055b04 100644
--- a/vllm/worker/xpu_worker.py
+++ b/vllm/worker/xpu_worker.py
@@ -8,10 +8,7 @@
 import torch
 import torch.distributed

-from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
-                         ModelConfig, ObservabilityConfig, ParallelConfig,
-                         PromptAdapterConfig, SchedulerConfig,
-                         SpeculativeConfig)
+from vllm.config import VllmConfig
 from vllm.distributed import (ensure_model_parallel_initialized,
                               init_distributed_environment)
 from vllm.logger import init_logger
@@ -19,7 +16,7 @@
 from vllm.platforms import current_platform
 from vllm.worker.cache_engine import CacheEngine
 from vllm.worker.worker import Worker
-from vllm.worker.worker_base import LoraNotSupportedWorkerBase
+from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase
 from vllm.worker.xpu_model_runner import XPUModelRunner

 logger = init_logger(__name__)
@@ -36,53 +33,32 @@ class XPUWorker(LoraNotSupportedWorkerBase, Worker):
     def __init__(
         self,
-        model_config: ModelConfig,
-        parallel_config: ParallelConfig,
-        scheduler_config: SchedulerConfig,
-        device_config: DeviceConfig,
-        cache_config: CacheConfig,
-        load_config: LoadConfig,
+        vllm_config: VllmConfig,
         local_rank: int,
         rank: int,
         distributed_init_method: str,
-        lora_config: Optional[LoRAConfig] = None,
-        speculative_config: Optional[SpeculativeConfig] = None,
-        prompt_adapter_config: Optional[PromptAdapterConfig] = None,
         is_driver_worker: bool = False,
-        observability_config: Optional[ObservabilityConfig] = None,
     ) -> None:
+        WorkerBase.__init__(self, vllm_config=vllm_config)
+        device_config = self.device_config
+        parallel_config = self.parallel_config
         assert device_config.device_type == "xpu"
         assert current_platform.is_xpu()
-        self.model_config = model_config
-        self.parallel_config = parallel_config
         self.parallel_config.rank = rank
-        self.scheduler_config = scheduler_config
-        self.device_config = device_config
-        self.cache_config = cache_config
-        self.load_config = load_config
+
         self.local_rank = local_rank
         self.rank = rank
         self.distributed_init_method = distributed_init_method
-        self.lora_config = lora_config
-        self.prompt_adapter_config = prompt_adapter_config
         self.is_driver_worker = is_driver_worker
-        self.observability_config = observability_config
         if parallel_config and is_driver_worker:
             assert rank % parallel_config.tensor_parallel_size == 0, \
                 "Driver worker should be rank 0 of tensor parallel group."
         self.model_runner = XPUModelRunner(  # type: ignore
-            model_config,
-            parallel_config,
-            scheduler_config,
-            device_config,
-            cache_config,
-            load_config=self.load_config,
-            lora_config=self.lora_config,
+            vllm_config=vllm_config,
             kv_cache_dtype=self.cache_config.cache_dtype,
             is_driver_worker=is_driver_worker,
-            observability_config=self.observability_config,
         )
         # Uninitialized cache engine. Will be initialized by
         # initialize_cache.
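
The common thread across the hunks above is that every worker and model runner now hands the whole VllmConfig to its base class instead of copying eight sub-configs field by field. Below is a minimal, self-contained sketch of that delegation pattern. The Toy*/Example* classes are illustrative stand-ins, not the real vLLM classes; the real VllmConfig, WorkerBase, and ModelRunnerBase are the ones touched in this diff.

from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyModelConfig:
    model: str = "facebook/opt-125m"


@dataclass
class ToyCacheConfig:
    block_size: int = 16
    cache_dtype: str = "auto"


@dataclass
class ToyVllmConfig:
    # Same shape as VllmConfig: one top-level object that owns the sub-configs.
    model_config: Optional[ToyModelConfig] = None
    cache_config: Optional[ToyCacheConfig] = None


class ToyWorkerBase:
    def __init__(self, vllm_config: ToyVllmConfig) -> None:
        # Fan the top-level config out into attributes exactly once,
        # instead of every subclass repeating the same assignments.
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config


class ExampleWorker(ToyWorkerBase):
    def __init__(self, vllm_config: ToyVllmConfig, rank: int = 0) -> None:
        ToyWorkerBase.__init__(self, vllm_config)
        # Rebinding a short local name, as several hunks above do with
        # "cache_config = self.cache_config", keeps the body readable.
        cache_config = self.cache_config
        self.rank = rank
        self.block_size = cache_config.block_size


worker = ExampleWorker(
    ToyVllmConfig(model_config=ToyModelConfig(), cache_config=ToyCacheConfig()))
print(worker.model_config.model, worker.block_size)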
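The two NotImplementedError stubs added to openvino_model_runner.py near the top of this chunk follow from the same refactor: once the OpenVINO runner initializes itself through ModelRunnerBase, it is expected to satisfy that interface even though this backend never uses the two input-preparation hooks. A toy sketch of that relationship follows; the Toy* names are illustrative, and treating both hooks as abstract on the base class is an assumption here rather than something shown in the diff.

from abc import ABC, abstractmethod


class ToyRunnerBase(ABC):
    # Stand-in for ModelRunnerBase: shared config handling would live here,
    # and the two hooks below are assumed to be abstract.
    @abstractmethod
    def prepare_model_input(self, *args, **kwargs):
        ...

    @abstractmethod
    def make_model_input_from_broadcasted_tensor_dict(self, *args, **kwargs):
        ...


class ToyOpenVINORunner(ToyRunnerBase):
    # This backend drives execution through its own execute_model() path,
    # so the unused hooks are stubbed out rather than implemented.
    def prepare_model_input(self, *args, **kwargs):
        raise NotImplementedError

    def make_model_input_from_broadcasted_tensor_dict(self, *args, **kwargs):
        raise NotImplementedError


runner = ToyOpenVINORunner()  # instantiation succeeds because both hooks exist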
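From a caller's point of view, the visible change is that the long list of per-domain config arguments collapses into a single keyword. A hedged sketch of constructing the CUDA Worker under the new signature: it assumes EngineArgs.create_engine_config() now returns the VllmConfig object, and the model name, load_format value, and the get_ip/get_open_port/get_distributed_init_method helpers from vllm.utils are placeholders chosen for the sketch rather than part of this diff.

from vllm.engine.arg_utils import EngineArgs
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
from vllm.worker.worker import Worker

# Build every sub-config in one shot; the result is the single VllmConfig
# object that the refactored constructors expect.
engine_config = EngineArgs(model="facebook/opt-125m",
                           load_format="dummy").create_engine_config()

worker = Worker(
    vllm_config=engine_config,  # replaces the old per-config keyword arguments
    local_rank=0,
    rank=0,
    distributed_init_method=get_distributed_init_method(get_ip(),
                                                        get_open_port()),
    is_driver_worker=True,
)
worker.init_device()
worker.load_model()

Code that reads per-domain configs off the worker afterwards keeps working, because WorkerBase.__init__ (added above) re-exposes each sub-config as an attribute: worker.model_config, worker.cache_config, and so on are still there, now sourced from the one VllmConfig.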