From a111d0151ffed94582bec65635979e04e5b63676 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 21 Nov 2024 21:00:32 -0800 Subject: [PATCH] [platforms] absorb worker cls difference into platforms folder (#10555) Signed-off-by: youkaichao Co-authored-by: Nick Hill --- vllm/config.py | 238 ++++++++++++------------ vllm/engine/arg_utils.py | 11 +- vllm/executor/cpu_executor.py | 7 +- vllm/executor/gpu_executor.py | 49 +---- vllm/executor/hpu_executor.py | 5 +- vllm/executor/multiproc_gpu_executor.py | 2 +- vllm/executor/neuron_executor.py | 5 +- vllm/executor/openvino_executor.py | 8 +- vllm/executor/ray_gpu_executor.py | 16 +- vllm/executor/ray_hpu_executor.py | 36 +--- vllm/executor/ray_tpu_executor.py | 19 +- vllm/executor/xpu_executor.py | 14 +- vllm/platforms/cpu.py | 2 + vllm/platforms/cuda.py | 21 ++- vllm/platforms/hpu.py | 23 +++ vllm/platforms/neuron.py | 14 ++ vllm/platforms/openvino.py | 18 ++ vllm/platforms/rocm.py | 20 ++ vllm/platforms/tpu.py | 12 ++ vllm/platforms/xpu.py | 6 + vllm/worker/worker_base.py | 30 +-- 21 files changed, 273 insertions(+), 283 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index d1c6a850cb78c..b5f2116e3557b 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -926,56 +926,56 @@ def _verify_load_format(self) -> None: f"{rocm_supported_load_format}") +@dataclass class ParallelConfig: - """Configuration for the distributed execution. + """Configuration for the distributed execution.""" - Args: - pipeline_parallel_size: Number of pipeline parallel groups. - tensor_parallel_size: Number of tensor parallel groups. - worker_use_ray: Deprecated, use distributed_executor_backend instead. - max_parallel_loading_workers: Maximum number of multiple batches - when load model sequentially. To avoid RAM OOM when using tensor - parallel and large models. - disable_custom_all_reduce: Disable the custom all-reduce kernel and - fall back to NCCL. - tokenizer_pool_config: Config for the tokenizer pool. - If None, will use synchronous tokenization. - ray_workers_use_nsight: Whether to profile Ray workers with nsight, see - https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler. - placement_group: ray distributed model workers placement group. - distributed_executor_backend: Backend to use for distributed model - workers, either "ray" or "mp" (multiprocessing). If the product - of pipeline_parallel_size and tensor_parallel_size is less than - or equal to the number of GPUs available, "mp" will be used to - keep processing on a single host. Otherwise, this will default - to "ray" if Ray is installed and fail otherwise. Note that tpu - and hpu only support Ray for distributed inference. - """ + pipeline_parallel_size: int = 1 # Number of pipeline parallel groups. + tensor_parallel_size: int = 1 # Number of tensor parallel groups. 
- def __init__( - self, - pipeline_parallel_size: int, - tensor_parallel_size: int, - worker_use_ray: Optional[bool] = None, - max_parallel_loading_workers: Optional[int] = None, - disable_custom_all_reduce: bool = False, - tokenizer_pool_config: Optional[TokenizerPoolConfig] = None, - ray_workers_use_nsight: bool = False, - placement_group: Optional["PlacementGroup"] = None, - distributed_executor_backend: Optional[Union[ - str, Type["ExecutorBase"]]] = None, - ) -> None: - self.pipeline_parallel_size = pipeline_parallel_size - self.tensor_parallel_size = tensor_parallel_size - self.distributed_executor_backend = distributed_executor_backend - self.max_parallel_loading_workers = max_parallel_loading_workers - self.disable_custom_all_reduce = disable_custom_all_reduce - self.tokenizer_pool_config = tokenizer_pool_config - self.ray_workers_use_nsight = ray_workers_use_nsight - self.placement_group = placement_group - self.world_size = pipeline_parallel_size * self.tensor_parallel_size - - if worker_use_ray: + # Deprecated, use distributed_executor_backend instead. + worker_use_ray: Optional[bool] = None + + # Maximum number of multiple batches + # when load model sequentially. To avoid RAM OOM when using tensor + # parallel and large models. + max_parallel_loading_workers: Optional[int] = None + + # Disable the custom all-reduce kernel and fall back to NCCL. + disable_custom_all_reduce: bool = False + + # Config for the tokenizer pool. If None, will use synchronous tokenization. + tokenizer_pool_config: Optional[TokenizerPoolConfig] = None + + # Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler. + ray_workers_use_nsight: bool = False + + # ray distributed model workers placement group. + placement_group: Optional["PlacementGroup"] = None + + # Backend to use for distributed model + # workers, either "ray" or "mp" (multiprocessing). If the product + # of pipeline_parallel_size and tensor_parallel_size is less than + # or equal to the number of GPUs available, "mp" will be used to + # keep processing on a single host. Otherwise, this will default + # to "ray" if Ray is installed and fail otherwise. Note that tpu + # and hpu only support Ray for distributed inference. + distributed_executor_backend: Optional[Union[str, + Type["ExecutorBase"]]] = None + + # the full name of the worker class to use. If "auto", the worker class + # will be determined based on the platform. + worker_cls: str = "auto" + + world_size: int = field(init=False) + + rank: int = 0 + + def __post_init__(self) -> None: + self.world_size = self.pipeline_parallel_size * \ + self.tensor_parallel_size + + if self.worker_use_ray: if self.distributed_executor_backend is None: self.distributed_executor_backend = "ray" elif not self.use_ray: @@ -1026,7 +1026,6 @@ def __init__( backend) self._verify_args() - self.rank: int = 0 @property def use_ray(self) -> bool: @@ -1059,100 +1058,97 @@ def _verify_args(self) -> None: "run with Ray.") +@dataclass class SchedulerConfig: - """Scheduler configuration. + """Scheduler configuration.""" - Args: - task: The task to use the model for. - max_num_batched_tokens: Maximum number of tokens to be processed in - a single iteration. - max_num_seqs: Maximum number of sequences to be processed in a single - iteration. - max_model_len: Maximum length of a sequence (including prompt - and generated text). 
- num_lookahead_slots: The number of slots to allocate per sequence per - step, beyond the known token ids. This is used in speculative - decoding to store KV activations of tokens which may or may not be - accepted. - delay_factor: Apply a delay (of delay factor multiplied by previous - prompt latency) before scheduling next prompt. - enable_chunked_prefill: If True, prefill requests can be chunked based - on the remaining max_num_batched_tokens. - preemption_mode: Whether to perform preemption by swapping or - recomputation. If not specified, we determine the mode as follows: - We use recomputation by default since it incurs lower overhead than - swapping. However, when the sequence group has multiple sequences - (e.g., beam search), recomputation is not currently supported. In - such a case, we use swapping instead. - send_delta_data: Private API. If used, scheduler sends delta data to - workers instead of an entire data. It should be enabled only - when SPMD worker architecture is enabled. I.e., - VLLM_USE_RAY_SPMD_WORKER=1 - policy: The scheduling policy to use. "fcfs" (default) or "priority". - """ + task: str = "generate" # The task to use the model for. + + # Maximum number of tokens to be processed in a single iteration. + max_num_batched_tokens: int = field(default=None) # type: ignore + + # Maximum number of sequences to be processed in a single iteration. + max_num_seqs: int = 128 + + # Maximum length of a sequence (including prompt and generated text). + max_model_len: int = 8192 + + # The number of slots to allocate per sequence per + # step, beyond the known token ids. This is used in speculative + # decoding to store KV activations of tokens which may or may not be + # accepted. + num_lookahead_slots: int = 0 + + # Apply a delay (of delay factor multiplied by previous + # prompt latency) before scheduling next prompt. + delay_factor: float = 0.0 + + # If True, prefill requests can be chunked based + # on the remaining max_num_batched_tokens. + enable_chunked_prefill: bool = False + + is_multimodal_model: bool = False - def __init__(self, - task: _Task, - max_num_batched_tokens: Optional[int], - max_num_seqs: int, - max_model_len: int, - num_lookahead_slots: int = 0, - delay_factor: float = 0.0, - enable_chunked_prefill: bool = False, - is_multimodal_model: bool = False, - preemption_mode: Optional[str] = None, - num_scheduler_steps: int = 1, - multi_step_stream_outputs: bool = False, - send_delta_data: bool = False, - policy: str = "fcfs") -> None: - if max_num_batched_tokens is None: - if enable_chunked_prefill: - if num_scheduler_steps > 1: + # Whether to perform preemption by swapping or + # recomputation. If not specified, we determine the mode as follows: + # We use recomputation by default since it incurs lower overhead than + # swapping. However, when the sequence group has multiple sequences + # (e.g., beam search), recomputation is not currently supported. In + # such a case, we use swapping instead. + preemption_mode: Optional[str] = None + + num_scheduler_steps: int = 1 + + multi_step_stream_outputs: bool = False + + # Private API. If used, scheduler sends delta data to + # workers instead of an entire data. It should be enabled only + # when SPMD worker architecture is enabled. I.e., + # VLLM_USE_RAY_SPMD_WORKER=1 + send_delta_data: bool = False + + # The scheduling policy to use. "fcfs" (default) or "priority". 
+ policy: str = "fcfs" + + chunked_prefill_enabled: bool = field(init=False) + + def __post_init__(self) -> None: + if self.max_num_batched_tokens is None: + if self.enable_chunked_prefill: + if self.num_scheduler_steps > 1: # Multi-step Chunked-Prefill doesn't allow prompt-chunking # for now. Have max_num_batched_tokens set to max_model_len # so we don't reject sequences on account of a short # max_num_batched_tokens. - max_num_batched_tokens = max(max_model_len, 2048) + self.max_num_batched_tokens = max(self.max_model_len, 2048) else: # It is the values that have the best balance between ITL # and TTFT on A100. Note it is not optimized for throughput. - max_num_batched_tokens = 512 + self.max_num_batched_tokens = 512 else: # If max_model_len is too short, use 2048 as the default value # for higher throughput. - max_num_batched_tokens = max(max_model_len, 2048) + self.max_num_batched_tokens = max(self.max_model_len, 2048) - if task == "embedding": + if self.task == "embedding": # For embedding, choose specific value for higher throughput - max_num_batched_tokens = max( - max_num_batched_tokens, + self.max_num_batched_tokens = max( + self.max_num_batched_tokens, _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS, ) - if is_multimodal_model: + if self.is_multimodal_model: # The value needs to be at least the number of multimodal tokens - max_num_batched_tokens = max( - max_num_batched_tokens, + self.max_num_batched_tokens = max( + self.max_num_batched_tokens, _MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS, ) - self.max_num_batched_tokens = max_num_batched_tokens - - if enable_chunked_prefill: + if self.enable_chunked_prefill: logger.info( "Chunked prefill is enabled with max_num_batched_tokens=%d.", self.max_num_batched_tokens) - self.task: Final = task - self.max_num_seqs = max_num_seqs - self.max_model_len = max_model_len - self.num_lookahead_slots = num_lookahead_slots - self.delay_factor = delay_factor - self.chunked_prefill_enabled = enable_chunked_prefill - self.preemption_mode = preemption_mode - self.num_scheduler_steps = num_scheduler_steps - self.multi_step_stream_outputs = multi_step_stream_outputs - self.send_delta_data = send_delta_data - self.policy = policy + self.chunked_prefill_enabled = self.enable_chunked_prefill self._verify_args() def _verify_args(self) -> None: @@ -2293,10 +2289,10 @@ class VllmConfig: model_config: ModelConfig = field(default=None, init=True) # type: ignore cache_config: CacheConfig = field(default=None, init=True) # type: ignore - parallel_config: ParallelConfig = field(default=None, - init=True) # type: ignore - scheduler_config: SchedulerConfig = field(default=None, - init=True) # type: ignore + parallel_config: ParallelConfig = field(default_factory=ParallelConfig, + init=True) + scheduler_config: SchedulerConfig = field(default_factory=SchedulerConfig, + init=True) device_config: DeviceConfig = field(default=None, init=True) # type: ignore load_config: LoadConfig = field(default=None, init=True) # type: ignore diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 88862a185ac75..82f1ef51255e9 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -191,6 +191,7 @@ class EngineArgs: override_neuron_config: Optional[Dict[str, Any]] = None override_pooler_config: Optional[PoolerConfig] = None compilation_config: Optional[CompilationConfig] = None + worker_cls: str = "auto" def __post_init__(self): if not self.tokenizer: @@ -887,6 +888,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: 'compilers, using 
-O without space is also ' 'supported. -O3 is equivalent to -O 3.') + parser.add_argument( + '--worker-cls', + type=str, + default="auto", + help='The worker class to use for distributed execution.') + return parser @classmethod @@ -999,7 +1006,9 @@ def create_engine_config(self) -> VllmConfig: self.tokenizer_pool_extra_config, ), ray_workers_use_nsight=self.ray_workers_use_nsight, - distributed_executor_backend=self.distributed_executor_backend) + distributed_executor_backend=self.distributed_executor_backend, + worker_cls=self.worker_cls, + ) max_model_len = model_config.max_model_len use_long_context = max_model_len > 32768 diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py index 1542a2ae367eb..336f9bc8efb20 100644 --- a/vllm/executor/cpu_executor.py +++ b/vllm/executor/cpu_executor.py @@ -115,13 +115,8 @@ def _create_worker( local_rank: int = 0, rank: int = 0, ): - worker_module_name = "vllm.worker.cpu_worker" - worker_class_name = "CPUWorker" - wrapper = WorkerWrapperBase( - worker_module_name=worker_module_name, - worker_class_name=worker_class_name, - ) + wrapper = WorkerWrapperBase(vllm_config=self.vllm_config) assert self.distributed_init_method is not None diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index c65d0836e5ff7..7fa34456028dd 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union +from typing import Any, Dict, List, Optional, Set, Tuple, Union from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.logger import init_logger @@ -8,19 +8,14 @@ from vllm.sequence import ExecuteModelRequest, PoolerOutput from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, make_async) -from vllm.worker.worker_base import WorkerBase, WorkerWrapperBase +from vllm.worker.worker_base import WorkerWrapperBase logger = init_logger(__name__) -def create_worker(worker_module_name: str, worker_class_name: str, - worker_class_fn: Optional[Callable[[], Type[WorkerBase]]], - **kwargs): - wrapper = WorkerWrapperBase( - worker_module_name=worker_module_name, - worker_class_name=worker_class_name, - worker_class_fn=worker_class_fn, - ) +def create_worker(**kwargs): + vllm_config = kwargs.get("vllm_config") + wrapper = WorkerWrapperBase(vllm_config=vllm_config) wrapper.init_worker(**kwargs) return wrapper.worker @@ -57,43 +52,11 @@ def _get_worker_kwargs( or (rank % self.parallel_config.tensor_parallel_size == 0), ) - def _get_worker_module_and_class( - self) -> Tuple[str, str, Optional[Callable[[], Type[WorkerBase]]]]: - worker_class_fn = None - if self.scheduler_config.is_multi_step: - worker_module_name = "vllm.worker.multi_step_worker" - worker_class_name = "MultiStepWorker" - elif self.speculative_config: - worker_module_name = "vllm.spec_decode.spec_decode_worker" - worker_class_name = "create_spec_worker" - else: - worker_module_name = "vllm.worker.worker" - worker_class_name = "Worker" - return (worker_module_name, worker_class_name, worker_class_fn) - - def _get_create_worker_kwargs( - self, - local_rank: int = 0, - rank: int = 0, - distributed_init_method: Optional[str] = None) -> Dict: - worker_kwargs = self._get_worker_kwargs(local_rank, rank, - distributed_init_method) - - (worker_module_name, worker_class_name, - worker_class_fn) = self._get_worker_module_and_class() - worker_kwargs.update( - worker_module_name=worker_module_name, - worker_class_name=worker_class_name, - 
worker_class_fn=worker_class_fn, - ) - - return worker_kwargs - def _create_worker(self, local_rank: int = 0, rank: int = 0, distributed_init_method: Optional[str] = None): - return create_worker(**self._get_create_worker_kwargs( + return create_worker(**self._get_worker_kwargs( local_rank=local_rank, rank=rank, distributed_init_method=distributed_init_method)) diff --git a/vllm/executor/hpu_executor.py b/vllm/executor/hpu_executor.py index 220e9eee87bb3..c9b7bfa71edfa 100644 --- a/vllm/executor/hpu_executor.py +++ b/vllm/executor/hpu_executor.py @@ -48,10 +48,7 @@ def _create_worker(self, local_rank: int = 0, rank: int = 0, distributed_init_method: Optional[str] = None): - wrapper = WorkerWrapperBase( - worker_module_name="vllm.worker.hpu_worker", - worker_class_name="HPUWorker", - ) + wrapper = WorkerWrapperBase(vllm_config=self.vllm_config) wrapper.init_worker(**self._get_worker_kwargs(local_rank, rank, distributed_init_method)) return wrapper.worker diff --git a/vllm/executor/multiproc_gpu_executor.py b/vllm/executor/multiproc_gpu_executor.py index 3eb14fb931925..a6c05a71d2b6f 100644 --- a/vllm/executor/multiproc_gpu_executor.py +++ b/vllm/executor/multiproc_gpu_executor.py @@ -90,7 +90,7 @@ def _init_executor(self) -> None: result_handler, partial( create_worker, - **self._get_create_worker_kwargs( + **self._get_worker_kwargs( rank=rank, local_rank=rank, distributed_init_method=distributed_init_method, diff --git a/vllm/executor/neuron_executor.py b/vllm/executor/neuron_executor.py index 02d37cd7fbf23..31e6fdc3ab1bb 100644 --- a/vllm/executor/neuron_executor.py +++ b/vllm/executor/neuron_executor.py @@ -7,6 +7,7 @@ from vllm.sequence import ExecuteModelRequest from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, make_async) +from vllm.worker.worker_base import WorkerWrapperBase logger = init_logger(__name__) @@ -25,10 +26,10 @@ def _init_executor(self) -> None: self._init_worker() def _init_worker(self): - from vllm.worker.neuron_worker import NeuronWorker + wrapper = WorkerWrapperBase(vllm_config=self.vllm_config) distributed_init_method = get_distributed_init_method( get_ip(), get_open_port()) - self.driver_worker = NeuronWorker( + self.driver_worker = wrapper.init_worker( vllm_config=self.vllm_config, local_rank=0, rank=0, diff --git a/vllm/executor/openvino_executor.py b/vllm/executor/openvino_executor.py index d06b0ccb7906e..dcd4b7621381d 100644 --- a/vllm/executor/openvino_executor.py +++ b/vllm/executor/openvino_executor.py @@ -14,6 +14,7 @@ from vllm.sequence import ExecuteModelRequest from vllm.utils import (GiB_bytes, get_distributed_init_method, get_ip, get_open_port, make_async) +from vllm.worker.worker_base import WorkerWrapperBase logger = init_logger(__name__) @@ -38,15 +39,12 @@ def _init_executor(self) -> None: self._init_worker() def _init_worker(self): - from vllm.worker.openvino_worker import OpenVINOWorker - assert ( - self.parallel_config.world_size == 1 - ), "OpenVINOExecutor only supports single CPU socket currently." 
+ wrapper = WorkerWrapperBase(vllm_config=self.vllm_config) distributed_init_method = get_distributed_init_method( get_ip(), get_open_port()) - self.driver_worker = OpenVINOWorker( + self.driver_worker = wrapper.init_worker( ov_core=self.ov_core, vllm_config=self.vllm_config, local_rank=0, diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index 66bab2c686c67..810b0f06ff7b2 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -91,17 +91,6 @@ def _configure_ray_workers_use_nsight(self, return ray_remote_kwargs - def _get_worker_wrapper_args(self) -> Dict[str, Any]: - (worker_module_name, worker_class_name, - worker_class_fn) = self._get_worker_module_and_class() - - return dict( - worker_module_name=worker_module_name, - worker_class_name=worker_class_name, - worker_class_fn=worker_class_fn, - trust_remote_code=self.model_config.trust_remote_code, - ) - # child class could overwrite this to return actual env vars. def _get_env_vars_to_be_updated(self): return self._env_vars_for_all_workers @@ -135,7 +124,6 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", # Create the workers. driver_ip = get_ip() - worker_wrapper_kwargs = self._get_worker_wrapper_args() for bundle_id, bundle in enumerate(placement_group.bundle_specs): if not bundle.get("GPU", 0): continue @@ -150,7 +138,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", num_gpus=num_gpus, scheduling_strategy=scheduling_strategy, **ray_remote_kwargs, - )(RayWorkerWrapper).remote(**worker_wrapper_kwargs) + )(RayWorkerWrapper).remote(vllm_config=self.vllm_config) if self.use_ray_spmd_worker: self.workers.append(worker) @@ -161,7 +149,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", # as the resource holder for the driver process. self.driver_dummy_worker = worker self.driver_worker = RayWorkerWrapper( - **worker_wrapper_kwargs) + vllm_config=self.vllm_config) else: # Else, added to the list of workers. 
self.workers.append(worker) diff --git a/vllm/executor/ray_hpu_executor.py b/vllm/executor/ray_hpu_executor.py index a24bab6df370e..6fe8c6c403358 100644 --- a/vllm/executor/ray_hpu_executor.py +++ b/vllm/executor/ray_hpu_executor.py @@ -2,8 +2,7 @@ import os from collections import defaultdict from itertools import islice, repeat -from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, - Type) +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple import msgspec @@ -18,7 +17,6 @@ from vllm.utils import (_run_task_with_lock, get_distributed_init_method, get_ip, get_open_port, get_vllm_instance_id, make_async) -from vllm.worker.worker_base import WorkerBase if ray is not None: from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy @@ -81,33 +79,6 @@ def shutdown(self) -> None: def finish_measurements(self): self._run_workers("finish_measurements") - def _get_worker_module_and_class( - self - ) -> Tuple[str, str, Optional[Callable[[], - Type[WorkerBase]]]]: # noqa: F821 - worker_class_fn = None - if self.scheduler_config.is_multi_step: - raise NotImplementedError( - "Multi-step execution is not implemented for HPU") - elif self.speculative_config: - raise NotImplementedError( - "Speculative decoding is not implemented for HPU") - else: - worker_module_name = "vllm.worker.hpu_worker" - worker_class_name = "HPUWorker" - return (worker_module_name, worker_class_name, worker_class_fn) - - def _get_worker_wrapper_args(self) -> Dict[str, Any]: - (worker_module_name, worker_class_name, - worker_class_fn) = self._get_worker_module_and_class() - - return dict( - worker_module_name=worker_module_name, - worker_class_name=worker_class_name, - worker_class_fn=worker_class_fn, - trust_remote_code=self.model_config.trust_remote_code, - ) - def _init_workers_ray(self, placement_group: "PlacementGroup", **ray_remote_kwargs): # Otherwise, the ray workers are allocated with a full GPU. @@ -128,7 +99,6 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", # Create the workers. driver_ip = get_ip() - worker_wrapper_kwargs = self._get_worker_wrapper_args() for bundle_id, bundle in enumerate(placement_group.bundle_specs): if not bundle.get("HPU", 0): continue @@ -144,7 +114,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", resources={'HPU': num_gpus}, scheduling_strategy=scheduling_strategy, **ray_remote_kwargs, - )(RayWorkerWrapper).remote(**worker_wrapper_kwargs) + )(RayWorkerWrapper).remote(vllm_config=self.vllm_config) if self.use_ray_spmd_worker: self.workers.append(worker) @@ -155,7 +125,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", # as the resource holder for the driver process. self.driver_dummy_worker = worker self.driver_worker = RayWorkerWrapper( - **worker_wrapper_kwargs) + vllm_config=self.vllm_config) else: # Else, added to the list of workers. 
self.workers.append(worker) diff --git a/vllm/executor/ray_tpu_executor.py b/vllm/executor/ray_tpu_executor.py index d02fecb46f007..c227b5e283c68 100644 --- a/vllm/executor/ray_tpu_executor.py +++ b/vllm/executor/ray_tpu_executor.py @@ -69,14 +69,6 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", placement_group_bundle_index=bundle_id, ) - assert self.speculative_config is None - if self.scheduler_config.is_multi_step: - worker_module_name = "vllm.worker.multi_step_tpu_worker" - worker_class_name = "MultiStepTPUWorker" - else: - worker_module_name = "vllm.worker.tpu_worker" - worker_class_name = "TPUWorker" - # GKE does not fetch environment information from metadata server # and instead sets these from within the Ray process. Therefore we # need to override the Ray environment variables manually. @@ -95,11 +87,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", resources={"TPU": 1}, scheduling_strategy=scheduling_strategy, **ray_remote_kwargs, - )(RayWorkerWrapper).remote( - worker_module_name=worker_module_name, - worker_class_name=worker_class_name, - trust_remote_code=self.model_config.trust_remote_code, - ) + )(RayWorkerWrapper).remote(vllm_config=self.vllm_config) if override_env: worker.override_env_vars.remote(override_env) @@ -109,10 +97,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", # as the resource holder for the driver process. self.driver_dummy_worker = worker self.driver_worker = RayWorkerWrapper( - worker_module_name=worker_module_name, - worker_class_name=worker_class_name, - trust_remote_code=self.model_config.trust_remote_code, - ) + vllm_config=self.vllm_config) else: # Else, added to the list of workers. self.workers.append(worker) diff --git a/vllm/executor/xpu_executor.py b/vllm/executor/xpu_executor.py index ba6177e51a453..722b86a95ff8a 100644 --- a/vllm/executor/xpu_executor.py +++ b/vllm/executor/xpu_executor.py @@ -1,4 +1,4 @@ -from typing import Callable, List, Optional, Tuple, Type, Union +from typing import List, Optional, Union from vllm.executor.executor_base import ExecutorAsyncBase from vllm.executor.gpu_executor import GPUExecutor @@ -6,7 +6,6 @@ from vllm.model_executor.layers.sampler import SamplerOutput from vllm.sequence import ExecuteModelRequest, PoolerOutput from vllm.utils import make_async -from vllm.worker.worker_base import WorkerBase logger = init_logger(__name__) @@ -22,17 +21,6 @@ def _init_executor(self) -> None: GPUExecutor._init_executor(self) - def _get_worker_module_and_class( - self) -> Tuple[str, str, Optional[Callable[[], Type[WorkerBase]]]]: - worker_class_fn = None - if self.speculative_config is not None: - raise NotImplementedError( - "XPU does not support speculative decoding") - else: - worker_module_name = "vllm.worker.xpu_worker" - worker_class_name = "XPUWorker" - return (worker_module_name, worker_class_name, worker_class_fn) - def execute_model( self, execute_model_req: ExecuteModelRequest ) -> Optional[List[Union[SamplerOutput, PoolerOutput]]]: diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py index 0c4c916406223..9be9031dc3baf 100644 --- a/vllm/platforms/cpu.py +++ b/vllm/platforms/cpu.py @@ -84,3 +84,5 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: "distributed executor backend."), parallel_config.distributed_executor_backend) parallel_config.distributed_executor_backend = "mp" + if parallel_config.worker_cls == "auto": + parallel_config.worker_cls = "vllm.worker.cpu_worker.CPUWorker" diff --git a/vllm/platforms/cuda.py 
b/vllm/platforms/cuda.py index b38dd7c936896..cf0d41081a5aa 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -4,7 +4,7 @@ import os from functools import lru_cache, wraps -from typing import Callable, List, Tuple, TypeVar +from typing import TYPE_CHECKING, Callable, List, Tuple, TypeVar import pynvml import torch @@ -16,6 +16,11 @@ from .interface import DeviceCapability, Platform, PlatformEnum +if TYPE_CHECKING: + from vllm.config import VllmConfig +else: + VllmConfig = None + logger = init_logger(__name__) _P = ParamSpec("_P") @@ -157,3 +162,17 @@ def is_full_nvlink(cls, physical_device_ids: List[int]) -> bool: " machine has no NVLink equipped.") return False return True + + @classmethod + def check_and_update_config(cls, vllm_config: VllmConfig) -> None: + parallel_config = vllm_config.parallel_config + scheduler_config = vllm_config.scheduler_config + if parallel_config.worker_cls == "auto": + if scheduler_config.is_multi_step: + parallel_config.worker_cls = \ + "vllm.worker.multi_step_worker.MultiStepWorker" + elif vllm_config.speculative_config: + parallel_config.worker_cls = \ + "vllm.spec_decode.spec_decode_worker.create_spec_worker" + else: + parallel_config.worker_cls = "vllm.worker.worker.Worker" diff --git a/vllm/platforms/hpu.py b/vllm/platforms/hpu.py index 36d944b3f24b8..a8f568d31d5a7 100644 --- a/vllm/platforms/hpu.py +++ b/vllm/platforms/hpu.py @@ -1,7 +1,14 @@ +from typing import TYPE_CHECKING + import torch from .interface import Platform, PlatformEnum, _Backend +if TYPE_CHECKING: + from vllm.config import VllmConfig +else: + VllmConfig = None + class HpuPlatform(Platform): _enum = PlatformEnum.HPU @@ -14,3 +21,19 @@ def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend: @staticmethod def inference_mode(): return torch.no_grad() + + @classmethod + def check_and_update_config(cls, vllm_config: VllmConfig) -> None: + + scheduler_config = vllm_config.scheduler_config + if scheduler_config.is_multi_step: + raise NotImplementedError( + "Multi-step execution is not implemented for HPU") + + if vllm_config.speculative_config is not None: + raise NotImplementedError( + "Speculative decoding is not implemented for HPU") + + parallel_config = vllm_config.parallel_config + if parallel_config.worker_cls == "auto": + parallel_config.worker_cls = "vllm.worker.hpu_worker.HPUWorker" diff --git a/vllm/platforms/neuron.py b/vllm/platforms/neuron.py index 57e3c0dfae84c..4c4d778ed3dd4 100644 --- a/vllm/platforms/neuron.py +++ b/vllm/platforms/neuron.py @@ -1,5 +1,12 @@ +from typing import TYPE_CHECKING + from .interface import Platform, PlatformEnum +if TYPE_CHECKING: + from vllm.config import VllmConfig +else: + VllmConfig = None + class NeuronPlatform(Platform): _enum = PlatformEnum.NEURON @@ -8,3 +15,10 @@ class NeuronPlatform(Platform): @classmethod def get_device_name(cls, device_id: int = 0) -> str: return "neuron" + + @classmethod + def check_and_update_config(cls, vllm_config: VllmConfig) -> None: + parallel_config = vllm_config.parallel_config + if parallel_config.worker_cls == "auto": + parallel_config.worker_cls = \ + "vllm.worker.neuron_worker.NeuronWorker" diff --git a/vllm/platforms/openvino.py b/vllm/platforms/openvino.py index 130b8eec1b386..33a41933e9fff 100644 --- a/vllm/platforms/openvino.py +++ b/vllm/platforms/openvino.py @@ -1,3 +1,5 @@ +from typing import TYPE_CHECKING + import torch import vllm.envs as envs @@ -5,6 +7,11 @@ from .interface import Platform, PlatformEnum, _Backend +if TYPE_CHECKING: + from vllm.config import 
VllmConfig +else: + VllmConfig = None + logger = init_logger(__name__) @@ -38,3 +45,14 @@ def is_openvino_gpu(self) -> bool: def is_pin_memory_available(self) -> bool: logger.warning("Pin memory is not supported on OpenViNO.") return False + + @classmethod + def check_and_update_config(cls, vllm_config: VllmConfig) -> None: + parallel_config = vllm_config.parallel_config + assert ( + parallel_config.world_size == 1 + ), "OpenVINOExecutor only supports single CPU socket currently." + + if parallel_config.worker_cls == "auto": + parallel_config.worker_cls = \ + "vllm.worker.openvino_worker.OpenVINOWorker" diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index c62241d8bb47b..3fe8c01c15787 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -1,5 +1,6 @@ import os from functools import lru_cache +from typing import TYPE_CHECKING import torch @@ -7,6 +8,11 @@ from .interface import DeviceCapability, Platform, PlatformEnum, _Backend +if TYPE_CHECKING: + from vllm.config import VllmConfig +else: + VllmConfig = None + logger = init_logger(__name__) try: @@ -58,3 +64,17 @@ def get_device_name(cls, device_id: int = 0) -> str: def get_device_total_memory(cls, device_id: int = 0) -> int: device_props = torch.cuda.get_device_properties(device_id) return device_props.total_memory + + @classmethod + def check_and_update_config(cls, vllm_config: VllmConfig) -> None: + parallel_config = vllm_config.parallel_config + scheduler_config = vllm_config.scheduler_config + if parallel_config.worker_cls == "auto": + if scheduler_config.is_multi_step: + parallel_config.worker_cls = \ + "vllm.worker.multi_step_worker.MultiStepWorker" + elif vllm_config.speculative_config: + parallel_config.worker_cls = \ + "vllm.spec_decode.spec_decode_worker.create_spec_worker" + else: + parallel_config.worker_cls = "vllm.worker.worker.Worker" diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index 863875ef5c2d6..513cfa54687dc 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -48,3 +48,15 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: if compilation_config.backend == "": compilation_config.backend = "openxla" + + assert vllm_config.speculative_config is None, \ + "TPU does not support speculative decoding" + + parallel_config = vllm_config.parallel_config + scheduler_config = vllm_config.scheduler_config + if parallel_config.worker_cls == "auto": + if scheduler_config.is_multi_step: + parallel_config.worker_cls = \ + "vllm.worker.multi_step_tpu_worker.MultiStepTPUWorker" + else: + parallel_config.worker_cls = "vllm.worker.tpu_worker.TPUWorker" diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py index 536e17a5f93e8..b2ee0ef2f71cd 100644 --- a/vllm/platforms/xpu.py +++ b/vllm/platforms/xpu.py @@ -57,6 +57,10 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: "mode.") model_config.enforce_eager = True + if vllm_config.speculative_config is not None: + raise NotImplementedError( + "XPU does not support speculative decoding") + # check and update parallel config parallel_config = vllm_config.parallel_config if (parallel_config.distributed_executor_backend is not None @@ -66,3 +70,5 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: " executor backend.", parallel_config.distributed_executor_backend) parallel_config.distributed_executor_backend = "ray" + if parallel_config.worker_cls == "auto": + parallel_config.worker_cls = "vllm.worker.xpu_worker.XPUWorker" diff --git a/vllm/worker/worker_base.py 
b/vllm/worker/worker_base.py index cf8a4946a71c4..e7fec6d17eecd 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -1,9 +1,8 @@ import dataclasses -import importlib import os import time from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union +from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union import torch @@ -15,7 +14,7 @@ from vllm.platforms import current_platform from vllm.sequence import ExecuteModelRequest, IntermediateTensors from vllm.utils import (enable_trace_function_call_for_thread, - update_environment_variables) + resolve_obj_by_qualname, update_environment_variables) from vllm.worker.model_runner_base import (BroadcastableModelInput, ModelRunnerBase, ModelRunnerInputBase) @@ -411,23 +410,14 @@ class WorkerWrapperBase: We first instantiate the WorkerWrapper, which remembers the worker module and class name. Then, when we call `update_environment_variables`, and the real initialization happens in `init_worker`. - - If worker_class_fn is specified, it will be executed to get the worker - class. - Otherwise, the worker class will be obtained by dynamically importing it - using worker_module_name and worker_class_name. """ def __init__( self, - worker_module_name: str, - worker_class_name: str, - trust_remote_code: bool = False, - worker_class_fn: Optional[Callable[[], - Type[WorkerBase]]] = None) -> None: - self.worker_module_name = worker_module_name - self.worker_class_name = worker_class_name - self.worker_class_fn = worker_class_fn + vllm_config: VllmConfig, + ) -> None: + self.vllm_config = vllm_config + trust_remote_code = vllm_config.model_config.trust_remote_code self.worker: Optional[WorkerBase] = None if trust_remote_code: # note: lazy import to avoid importing torch before initializing @@ -456,12 +446,8 @@ def init_worker(self, *args, **kwargs): from vllm.plugins import load_general_plugins load_general_plugins() - if self.worker_class_fn: - worker_class = self.worker_class_fn() - else: - mod = importlib.import_module(self.worker_module_name) - worker_class = getattr(mod, self.worker_class_name) - + worker_class = resolve_obj_by_qualname( + self.vllm_config.parallel_config.worker_cls) self.worker = worker_class(*args, **kwargs) assert self.worker is not None
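
Editor's note (not part of the patch): the sketch below illustrates how the pieces introduced above fit together, i.e. how a platform's check_and_update_config fills in ParallelConfig.worker_cls when it is left as "auto", and how WorkerWrapperBase.init_worker later resolves that dotted path with resolve_obj_by_qualname. The helper name pick_worker_class is hypothetical; the qualified worker paths, the "auto" default, and resolve_obj_by_qualname all come from the diff itself.

    from vllm.utils import resolve_obj_by_qualname

    def pick_worker_class(vllm_config):
        """Illustrative mirror of the platform-side selection added in this
        patch (see CudaPlatform.check_and_update_config above)."""
        parallel_config = vllm_config.parallel_config
        scheduler_config = vllm_config.scheduler_config
        # Platforms only fill in worker_cls when the user left it as "auto".
        if parallel_config.worker_cls == "auto":
            if scheduler_config.is_multi_step:
                parallel_config.worker_cls = \
                    "vllm.worker.multi_step_worker.MultiStepWorker"
            elif vllm_config.speculative_config:
                parallel_config.worker_cls = \
                    "vllm.spec_decode.spec_decode_worker.create_spec_worker"
            else:
                parallel_config.worker_cls = "vllm.worker.worker.Worker"
        # WorkerWrapperBase.init_worker then turns the dotted path into the
        # actual worker class before instantiating it.
        return resolve_obj_by_qualname(parallel_config.worker_cls)

With the new --worker-cls flag, a custom worker can also be selected from the command line without touching any executor code, e.g. (module path hypothetical, assuming the standard server entrypoint):

    vllm serve <model> --worker-cls my_pkg.workers.MyCustomWorker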