From 23c8419fa4677ec2abbdc194a70449dbf03b43f8 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Fri, 19 Jul 2024 18:25:06 -0700 Subject: [PATCH] [Core] Allow specifying custom Executor (#6557) Signed-off-by: Alvant --- tests/conftest.py | 4 + tests/engine/test_custom_executor.py | 91 +++++++++++++++++++ tests/tokenization/test_tokenizer_group.py | 21 ++++- vllm/config.py | 39 +++++--- vllm/engine/arg_utils.py | 18 +++- vllm/engine/async_llm_engine.py | 53 +++++++---- vllm/engine/llm_engine.py | 40 +++++--- vllm/executor/cpu_executor.py | 2 + vllm/executor/executor_base.py | 2 + vllm/executor/gpu_executor.py | 2 + vllm/executor/multiproc_gpu_executor.py | 2 + vllm/executor/neuron_executor.py | 2 + vllm/executor/openvino_executor.py | 2 + vllm/executor/ray_gpu_executor.py | 39 ++++---- vllm/executor/ray_xpu_executor.py | 24 ++--- vllm/executor/tpu_executor.py | 2 + vllm/executor/xpu_executor.py | 2 + .../tokenizer_group/__init__.py | 14 ++- .../tokenizer_group/base_tokenizer_group.py | 7 ++ .../tokenizer_group/ray_tokenizer_group.py | 4 +- .../tokenizer_group/tokenizer_group.py | 6 ++ vllm/worker/worker_base.py | 26 ++++-- 22 files changed, 310 insertions(+), 92 deletions(-) create mode 100644 tests/engine/test_custom_executor.py diff --git a/tests/conftest.py b/tests/conftest.py index 08b8814d983d3..71c4a539c4e8a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -564,6 +564,10 @@ def get_tokenizer_pool_config(tokenizer_group_type): return TokenizerPoolConfig(pool_size=1, pool_type="ray", extra_config={}) + if isinstance(tokenizer_group_type, type): + return TokenizerPoolConfig(pool_size=1, + pool_type=tokenizer_group_type, + extra_config={}) raise ValueError(f"Unknown tokenizer_group_type: {tokenizer_group_type}") diff --git a/tests/engine/test_custom_executor.py b/tests/engine/test_custom_executor.py new file mode 100644 index 0000000000000..bff0fc99ed022 --- /dev/null +++ b/tests/engine/test_custom_executor.py @@ -0,0 +1,91 @@ +import asyncio +import os + +import pytest + +from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs +from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.engine.llm_engine import LLMEngine +from vllm.executor.gpu_executor import GPUExecutor, GPUExecutorAsync +from vllm.sampling_params import SamplingParams + + +class Mock: + ... + + +class CustomGPUExecutor(GPUExecutor): + + def execute_model(self, *args, **kwargs): + # Drop marker to show that this was ran + with open(".marker", "w"): + ... + return super().execute_model(*args, **kwargs) + + +class CustomGPUExecutorAsync(GPUExecutorAsync): + + async def execute_model_async(self, *args, **kwargs): + with open(".marker", "w"): + ... + return await super().execute_model_async(*args, **kwargs) + + +@pytest.mark.parametrize("model", ["facebook/opt-125m"]) +def test_custom_executor_type_checking(model): + with pytest.raises(ValueError): + engine_args = EngineArgs(model=model, + distributed_executor_backend=Mock) + LLMEngine.from_engine_args(engine_args) + with pytest.raises(ValueError): + engine_args = AsyncEngineArgs(model=model, + distributed_executor_backend=Mock) + AsyncLLMEngine.from_engine_args(engine_args) + with pytest.raises(TypeError): + engine_args = AsyncEngineArgs( + model=model, distributed_executor_backend=CustomGPUExecutor) + AsyncLLMEngine.from_engine_args(engine_args) + + +@pytest.mark.parametrize("model", ["facebook/opt-125m"]) +def test_custom_executor(model, tmpdir): + cwd = os.path.abspath(".") + os.chdir(tmpdir) + try: + assert not os.path.exists(".marker") + + engine_args = EngineArgs( + model=model, distributed_executor_backend=CustomGPUExecutor) + engine = LLMEngine.from_engine_args(engine_args) + sampling_params = SamplingParams(max_tokens=1) + + engine.add_request("0", "foo", sampling_params) + engine.step() + + assert os.path.exists(".marker") + finally: + os.chdir(cwd) + + +@pytest.mark.parametrize("model", ["facebook/opt-125m"]) +def test_custom_executor_async(model, tmpdir): + cwd = os.path.abspath(".") + os.chdir(tmpdir) + try: + assert not os.path.exists(".marker") + + engine_args = AsyncEngineArgs( + model=model, distributed_executor_backend=CustomGPUExecutorAsync) + engine = AsyncLLMEngine.from_engine_args(engine_args) + sampling_params = SamplingParams(max_tokens=1) + + async def t(): + stream = await engine.add_request("0", "foo", sampling_params) + async for x in stream: + ... + + asyncio.run(t()) + + assert os.path.exists(".marker") + finally: + os.chdir(cwd) diff --git a/tests/tokenization/test_tokenizer_group.py b/tests/tokenization/test_tokenizer_group.py index 1b9a590750429..3faaf326f5422 100644 --- a/tests/tokenization/test_tokenizer_group.py +++ b/tests/tokenization/test_tokenizer_group.py @@ -7,17 +7,28 @@ import pytest from transformers import AutoTokenizer, PreTrainedTokenizerBase -from vllm.transformers_utils.tokenizer_group import get_tokenizer_group +from vllm.transformers_utils.tokenizer_group import (TokenizerGroup, + get_tokenizer_group) from vllm.transformers_utils.tokenizer_group.ray_tokenizer_group import ( RayTokenizerGroupPool) -from vllm.transformers_utils.tokenizer_group.tokenizer_group import ( - TokenizerGroup) from ..conftest import get_tokenizer_pool_config +class CustomTokenizerGroup(TokenizerGroup): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._i = 0 + + def encode(self, *args, **kwargs): + self._i += 1 + return super().encode(*args, **kwargs) + + @pytest.mark.asyncio -@pytest.mark.parametrize("tokenizer_group_type", [None, "ray"]) +@pytest.mark.parametrize("tokenizer_group_type", + [None, "ray", CustomTokenizerGroup]) async def test_tokenizer_group(tokenizer_group_type): reference_tokenizer = AutoTokenizer.from_pretrained("gpt2") tokenizer_group = get_tokenizer_group( @@ -36,6 +47,8 @@ async def test_tokenizer_group(tokenizer_group_type): PreTrainedTokenizerBase) assert tokenizer_group.get_lora_tokenizer( None) == await tokenizer_group.get_lora_tokenizer_async(None) + if tokenizer_group_type is CustomTokenizerGroup: + assert tokenizer_group._i > 0 @pytest.mark.asyncio diff --git a/vllm/config.py b/vllm/config.py index 9902a152e551a..8dde171576973 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1,7 +1,7 @@ import enum import json from dataclasses import dataclass, field, fields -from typing import TYPE_CHECKING, ClassVar, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, ClassVar, List, Optional, Tuple, Type, Union import torch from transformers import PretrainedConfig @@ -18,7 +18,10 @@ if TYPE_CHECKING: from ray.util.placement_group import PlacementGroup + from vllm.executor.executor_base import ExecutorBase from vllm.model_executor.model_loader.loader import BaseModelLoader + from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import ( + BaseTokenizerGroup) logger = init_logger(__name__) @@ -527,11 +530,12 @@ class TokenizerPoolConfig: pool type. """ pool_size: int - pool_type: str + pool_type: Union[str, Type["BaseTokenizerGroup"]] extra_config: dict def __post_init__(self): - if self.pool_type not in ("ray", ): + if self.pool_type not in ("ray", ) and not isinstance( + self.pool_type, type): raise ValueError(f"Unknown pool type: {self.pool_type}") if not isinstance(self.extra_config, dict): raise ValueError("extra_config must be a dictionary.") @@ -661,7 +665,8 @@ def __init__( tokenizer_pool_config: Optional[TokenizerPoolConfig] = None, ray_workers_use_nsight: bool = False, placement_group: Optional["PlacementGroup"] = None, - distributed_executor_backend: Optional[str] = None, + distributed_executor_backend: Optional[Union[ + str, Type["ExecutorBase"]]] = None, ) -> None: self.pipeline_parallel_size = pipeline_parallel_size self.tensor_parallel_size = tensor_parallel_size @@ -676,7 +681,7 @@ def __init__( if worker_use_ray: if self.distributed_executor_backend is None: self.distributed_executor_backend = "ray" - elif self.distributed_executor_backend != "ray": + elif not self.use_ray: raise ValueError(f"worker-use-ray can't be used with " f"distributed executor backend " f"'{self.distributed_executor_backend}'.") @@ -711,12 +716,25 @@ def __init__( self._verify_args() self.rank = 0 + @property + def use_ray(self) -> bool: + return self.distributed_executor_backend == "ray" or ( + isinstance(self.distributed_executor_backend, type) + and self.distributed_executor_backend.uses_ray) + def _verify_args(self) -> None: - if self.distributed_executor_backend not in ("ray", "mp", None): + # Lazy import to avoid circular import + from vllm.executor.executor_base import ExecutorBase + + if self.distributed_executor_backend not in ( + "ray", "mp", None) and not (isinstance( + self.distributed_executor_backend, type) and issubclass( + self.distributed_executor_backend, ExecutorBase)): raise ValueError( - "Unrecognized distributed executor backend. Supported values " - "are 'ray' or 'mp'.") - if self.distributed_executor_backend == "ray": + "Unrecognized distributed executor backend " + f"{self.distributed_executor_backend}. Supported " + "values are 'ray', 'mp' or custom ExecutorBase subclass.") + if self.use_ray: from vllm.executor import ray_utils ray_utils.assert_ray_available() if is_hip(): @@ -724,8 +742,7 @@ def _verify_args(self) -> None: logger.info( "Disabled the custom all-reduce kernel because it is not " "supported on AMD GPUs.") - if self.ray_workers_use_nsight and ( - not self.distributed_executor_backend == "ray"): + if self.ray_workers_use_nsight and not self.use_ray: raise ValueError("Unable to use nsight profiling unless workers " "run with Ray.") diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 28ae3448fb495..27a051fcbb2e9 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -2,16 +2,21 @@ import dataclasses import json from dataclasses import dataclass -from typing import List, Optional, Tuple, Union +from typing import TYPE_CHECKING, List, Optional, Tuple, Type, Union from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, EngineConfig, LoadConfig, LoRAConfig, ModelConfig, MultiModalConfig, ObservabilityConfig, ParallelConfig, PromptAdapterConfig, SchedulerConfig, SpeculativeConfig, TokenizerPoolConfig) +from vllm.executor.executor_base import ExecutorBase from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.utils import FlexibleArgumentParser +if TYPE_CHECKING: + from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import ( + BaseTokenizerGroup) + def nullable_str(val: str): if not val or val == "None": @@ -36,7 +41,11 @@ class EngineArgs: seed: int = 0 max_model_len: Optional[int] = None worker_use_ray: bool = False - distributed_executor_backend: Optional[str] = None + # Note: Specifying a custom executor backend by passing a class + # is intended for expert use only. The API may change without + # notice. + distributed_executor_backend: Optional[Union[str, + Type[ExecutorBase]]] = None pipeline_parallel_size: int = 1 tensor_parallel_size: int = 1 max_parallel_loading_workers: Optional[int] = None @@ -62,7 +71,10 @@ class EngineArgs: max_seq_len_to_capture: int = 8192 disable_custom_all_reduce: bool = False tokenizer_pool_size: int = 0 - tokenizer_pool_type: str = "ray" + # Note: Specifying a tokenizer pool by passing a class + # is intended for expert use only. The API may change without + # notice. + tokenizer_pool_type: Union[str, Type["BaseTokenizerGroup"]] = "ray" tokenizer_pool_extra_config: Optional[dict] = None enable_lora: bool = False max_loras: int = 1 diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 4df37df794115..ee6a111e77b17 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -7,12 +7,13 @@ from transformers import PreTrainedTokenizer import vllm.envs as envs -from vllm.config import DecodingConfig, ModelConfig +from vllm.config import DecodingConfig, EngineConfig, ModelConfig from vllm.core.scheduler import SchedulerOutputs from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_timeout import asyncio_timeout from vllm.engine.llm_engine import LLMEngine from vllm.engine.metrics import StatLoggerBase +from vllm.executor.executor_base import ExecutorAsyncBase from vllm.executor.ray_utils import initialize_ray_cluster, ray from vllm.inputs import LLMInputs, PromptInputs from vllm.logger import init_logger @@ -425,25 +426,19 @@ def __init__(self, self._request_tracker: RequestTracker @classmethod - def from_engine_args( - cls, - engine_args: AsyncEngineArgs, - start_engine_loop: bool = True, - usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, - stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, - ) -> "AsyncLLMEngine": - """Creates an async LLM engine from the engine arguments.""" - # Create the engine configs. - engine_config = engine_args.create_engine_config() - - if engine_args.engine_use_ray: - from vllm.executor import ray_utils - ray_utils.assert_ray_available() - + def _get_executor_cls( + cls, engine_config: EngineConfig) -> Type[ExecutorAsyncBase]: distributed_executor_backend = ( engine_config.parallel_config.distributed_executor_backend) - - if engine_config.device_config.device_type == "neuron": + if isinstance(distributed_executor_backend, type): + if not issubclass(distributed_executor_backend, ExecutorAsyncBase): + raise TypeError( + "distributed_executor_backend must be a subclass of " + f"ExecutorAsyncBase. Got {distributed_executor_backend}.") + if distributed_executor_backend.uses_ray: # type: ignore + initialize_ray_cluster(engine_config.parallel_config) + executor_class = distributed_executor_backend + elif engine_config.device_config.device_type == "neuron": from vllm.executor.neuron_executor import NeuronExecutorAsync executor_class = NeuronExecutorAsync elif engine_config.device_config.device_type == "tpu": @@ -482,9 +477,29 @@ def from_engine_args( else: from vllm.executor.gpu_executor import GPUExecutorAsync executor_class = GPUExecutorAsync + return executor_class + + @classmethod + def from_engine_args( + cls, + engine_args: AsyncEngineArgs, + start_engine_loop: bool = True, + usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, + stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, + ) -> "AsyncLLMEngine": + """Creates an async LLM engine from the engine arguments.""" + # Create the engine configs. + engine_config = engine_args.create_engine_config() + + if engine_args.engine_use_ray: + from vllm.executor import ray_utils + ray_utils.assert_ray_available() + + executor_class = cls._get_executor_cls(engine_config) + # Create the async LLM engine. engine = cls( - distributed_executor_backend == "ray", + executor_class.uses_ray, engine_args.engine_use_ray, **engine_config.to_dict(), executor_class=executor_class, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 342b82f31d7ae..66b25ad1931b0 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -7,9 +7,9 @@ from transformers import PreTrainedTokenizer import vllm.envs as envs -from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig, - LoRAConfig, ModelConfig, MultiModalConfig, - ObservabilityConfig, ParallelConfig, +from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, + EngineConfig, LoadConfig, LoRAConfig, ModelConfig, + MultiModalConfig, ObservabilityConfig, ParallelConfig, PromptAdapterConfig, SchedulerConfig, SpeculativeConfig) from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler, @@ -376,19 +376,20 @@ def _initialize_kv_caches(self) -> None: self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks) @classmethod - def from_engine_args( - cls, - engine_args: EngineArgs, - usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, - stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, - ) -> "LLMEngine": - """Creates an LLM engine from the engine arguments.""" - # Create the engine configs. - engine_config = engine_args.create_engine_config() + def _get_executor_cls(cls, + engine_config: EngineConfig) -> Type[ExecutorBase]: distributed_executor_backend = ( engine_config.parallel_config.distributed_executor_backend) # Initialize the cluster and specify the executor class. - if engine_config.device_config.device_type == "neuron": + if isinstance(distributed_executor_backend, type): + if not issubclass(distributed_executor_backend, ExecutorBase): + raise TypeError( + "distributed_executor_backend must be a subclass of " + f"ExecutorBase. Got {distributed_executor_backend}.") + if distributed_executor_backend.uses_ray: # type: ignore + initialize_ray_cluster(engine_config.parallel_config) + executor_class = distributed_executor_backend + elif engine_config.device_config.device_type == "neuron": from vllm.executor.neuron_executor import NeuronExecutor executor_class = NeuronExecutor elif engine_config.device_config.device_type == "tpu": @@ -422,6 +423,19 @@ def from_engine_args( else: from vllm.executor.gpu_executor import GPUExecutor executor_class = GPUExecutor + return executor_class + + @classmethod + def from_engine_args( + cls, + engine_args: EngineArgs, + usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, + stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, + ) -> "LLMEngine": + """Creates an LLM engine from the engine arguments.""" + # Create the engine configs. + engine_config = engine_args.create_engine_config() + executor_class = cls._get_executor_cls(engine_config) # Create the LLM engine. engine = cls( **engine_config.to_dict(), diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py index d3b60e3ff4260..23e429dac7232 100644 --- a/vllm/executor/cpu_executor.py +++ b/vllm/executor/cpu_executor.py @@ -17,6 +17,8 @@ class CPUExecutor(ExecutorBase): + uses_ray: bool = False + def _init_executor(self) -> None: assert self.device_config.device_type == "cpu" assert self.lora_config is None, "cpu backend doesn't support LoRA" diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index 713f3868f66ae..a848bc70941c1 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -18,6 +18,8 @@ class ExecutorBase(ABC): that can execute the model on multiple devices. """ + uses_ray: bool # whether the executor uses Ray for orchestration. + def __init__( self, model_config: ModelConfig, diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index 59cef83f2cdaa..3e77af0e20323 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -23,6 +23,8 @@ def create_worker(worker_module_name, worker_class_name, **kwargs): class GPUExecutor(ExecutorBase): + uses_ray: bool = False + def _init_executor(self) -> None: """Initialize the worker and load the model. """ diff --git a/vllm/executor/multiproc_gpu_executor.py b/vllm/executor/multiproc_gpu_executor.py index 08b417a45d046..9811fc2a55199 100644 --- a/vllm/executor/multiproc_gpu_executor.py +++ b/vllm/executor/multiproc_gpu_executor.py @@ -25,6 +25,8 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor): """Python multiprocessing-based multi-GPU executor""" + uses_ray: bool = False + def _init_executor(self) -> None: # Create the parallel GPU workers. world_size = self.parallel_config.world_size diff --git a/vllm/executor/neuron_executor.py b/vllm/executor/neuron_executor.py index 6b2cb3e2403f2..5d4c4f497f470 100644 --- a/vllm/executor/neuron_executor.py +++ b/vllm/executor/neuron_executor.py @@ -11,6 +11,8 @@ class NeuronExecutor(ExecutorBase): + uses_ray: bool = False + def _init_executor(self) -> None: assert (self.lora_config is None), "LoRA is not supported for Neuron backend." diff --git a/vllm/executor/openvino_executor.py b/vllm/executor/openvino_executor.py index 1ef37785b6d59..c52a1c9839d7b 100644 --- a/vllm/executor/openvino_executor.py +++ b/vllm/executor/openvino_executor.py @@ -18,6 +18,8 @@ class OpenVINOExecutor(ExecutorBase): + uses_ray: bool = False + def _init_executor(self) -> None: assert self.device_config.device_type == "openvino" assert self.lora_config is None, "OpenVINO backend doesn't support LoRA" diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index 0d44362b91672..e4aaeaa24c1bc 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -26,6 +26,8 @@ class RayGPUExecutor(DistributedGPUExecutor): + uses_ray: bool = True + def _init_executor(self) -> None: # If the env var is set, it uses the Ray's compiled DAG API # which optimizes the control plane overhead. @@ -47,7 +49,7 @@ def _init_executor(self) -> None: "VLLM_USE_RAY_SPMD_WORKER=1 requires " "VLLM_USE_RAY_COMPILED_DAG=1") - assert self.parallel_config.distributed_executor_backend == "ray" + assert self.uses_ray placement_group = self.parallel_config.placement_group # Disable Ray usage stats collection. @@ -75,6 +77,20 @@ def _configure_ray_workers_use_nsight(self, return ray_remote_kwargs + def _get_worker_wrapper_args(self) -> Dict[str, Any]: + if self.speculative_config is not None: + worker_module_name = "vllm.spec_decode.spec_decode_worker" + worker_class_name = "create_spec_worker" + else: + worker_module_name = "vllm.worker.worker" + worker_class_name = "Worker" + + return dict( + worker_module_name=worker_module_name, + worker_class_name=worker_class_name, + trust_remote_code=self.model_config.trust_remote_code, + ) + def _init_workers_ray(self, placement_group: "PlacementGroup", **ray_remote_kwargs): if (self.parallel_config.tensor_parallel_size == 1 @@ -97,6 +113,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", # Create the workers. driver_ip = get_ip() + worker_wrapper_kwargs = self._get_worker_wrapper_args() for bundle_id, bundle in enumerate(placement_group.bundle_specs): if not bundle.get("GPU", 0): continue @@ -106,23 +123,12 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", placement_group_bundle_index=bundle_id, ) - if self.speculative_config is not None: - worker_module_name = "vllm.spec_decode.spec_decode_worker" - worker_class_name = "create_spec_worker" - else: - worker_module_name = "vllm.worker.worker" - worker_class_name = "Worker" - worker = ray.remote( num_cpus=0, num_gpus=num_gpus, scheduling_strategy=scheduling_strategy, **ray_remote_kwargs, - )(RayWorkerWrapper).remote( - worker_module_name=worker_module_name, - worker_class_name=worker_class_name, - trust_remote_code=self.model_config.trust_remote_code, - ) + )(RayWorkerWrapper).remote(**worker_wrapper_kwargs) if self.use_ray_spmd_worker: self.workers.append(worker) @@ -133,10 +139,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", # as the resource holder for the driver process. self.driver_dummy_worker = worker self.driver_worker = RayWorkerWrapper( - worker_module_name=worker_module_name, - worker_class_name=worker_class_name, - trust_remote_code=self.model_config.trust_remote_code, - ) + **worker_wrapper_kwargs) else: # Else, added to the list of workers. self.workers.append(worker) @@ -378,7 +381,7 @@ def _compiled_ray_dag(self, enable_asyncio: bool): f"required, but found {current_version}") from ray.dag import InputNode, MultiOutputNode - assert self.parallel_config.distributed_executor_backend == "ray" + assert self.parallel_config.use_ray # Right now, compiled DAG requires at least 1 arg. We send # a dummy value for now. It will be fixed soon. diff --git a/vllm/executor/ray_xpu_executor.py b/vllm/executor/ray_xpu_executor.py index 2a93616ced06c..bdd8ba9032766 100644 --- a/vllm/executor/ray_xpu_executor.py +++ b/vllm/executor/ray_xpu_executor.py @@ -35,6 +35,8 @@ class RayXPUExecutor(DistributedGPUExecutor): + uses_ray: bool = True + def __init__( self, model_config: ModelConfig, @@ -107,6 +109,13 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: return num_gpu_blocks, num_cpu_blocks + def _get_worker_wrapper_args(self) -> Dict[str, Any]: + return dict( + worker_module_name="vllm.worker.xpu_worker", + worker_class_name="XPUWorker", + trust_remote_code=self.model_config.trust_remote_code, + ) + def _init_workers_ray(self, placement_group: "PlacementGroup", **ray_remote_kwargs): if self.parallel_config.tensor_parallel_size == 1: @@ -124,6 +133,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", # Create the workers. driver_ip = get_ip() + worker_wrapper_kwargs = self._get_worker_wrapper_args() for bundle_id, bundle in enumerate(placement_group.bundle_specs): if not bundle.get("GPU", 0): continue @@ -137,22 +147,14 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", num_gpus=num_gpus, scheduling_strategy=scheduling_strategy, **ray_remote_kwargs, - )(RayWorkerWrapper).remote( - worker_module_name="vllm.worker.xpu_worker", - worker_class_name="XPUWorker", - trust_remote_code=self.model_config.trust_remote_code, - ) + )(RayWorkerWrapper).remote(**worker_wrapper_kwargs) worker_ip = ray.get(worker.get_node_ip.remote()) if worker_ip == driver_ip and self.driver_dummy_worker is None: # If the worker is on the same node as the driver, we use it # as the resource holder for the driver process. self.driver_dummy_worker = worker - self.driver_worker = RayWorkerWrapper( - worker_module_name="vllm.worker.xpu_worker", - worker_class_name="XPUWorker", - trust_remote_code=self.model_config.trust_remote_code, - ) + self.driver_worker = RayWorkerWrapper(**worker_wrapper_kwargs) else: # Else, added to the list of workers. self.workers.append(worker) @@ -337,7 +339,7 @@ def _compiled_ray_dag(self, enable_asyncio: bool): f"required, but found {current_version}") from ray.dag import InputNode, MultiOutputNode - assert self.parallel_config.distributed_executor_backend == "ray" + assert self.parallel_config.use_ray # Right now, compiled DAG requires at least 1 arg. We send # a dummy value for now. It will be fixed soon. diff --git a/vllm/executor/tpu_executor.py b/vllm/executor/tpu_executor.py index d906a6cc39dd7..1b5bb5c755ef2 100644 --- a/vllm/executor/tpu_executor.py +++ b/vllm/executor/tpu_executor.py @@ -14,6 +14,8 @@ class TPUExecutor(ExecutorBase): + uses_ray: bool = False + def _init_executor(self) -> None: assert not self.scheduler_config.chunked_prefill_enabled, ( "Chunked prefill is not yet supported for TPU backend") diff --git a/vllm/executor/xpu_executor.py b/vllm/executor/xpu_executor.py index f6550cce9ab1a..9feae6a05ba9b 100644 --- a/vllm/executor/xpu_executor.py +++ b/vllm/executor/xpu_executor.py @@ -18,6 +18,8 @@ class XPUExecutor(GPUExecutor): + uses_ray: bool = False + def __init__( self, model_config: ModelConfig, diff --git a/vllm/transformers_utils/tokenizer_group/__init__.py b/vllm/transformers_utils/tokenizer_group/__init__.py index 0195c40c27f60..9f54f5409b181 100644 --- a/vllm/transformers_utils/tokenizer_group/__init__.py +++ b/vllm/transformers_utils/tokenizer_group/__init__.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Type from vllm.config import TokenizerPoolConfig from vllm.executor.ray_utils import ray @@ -16,18 +16,22 @@ def get_tokenizer_group(tokenizer_pool_config: Optional[TokenizerPoolConfig], **init_kwargs) -> BaseTokenizerGroup: + tokenizer_cls: Type[BaseTokenizerGroup] if tokenizer_pool_config is None: - return TokenizerGroup(**init_kwargs) - if tokenizer_pool_config.pool_type == "ray": + tokenizer_cls = TokenizerGroup + elif isinstance(tokenizer_pool_config.pool_type, type) and issubclass( + tokenizer_pool_config.pool_type, BaseTokenizerGroup): + tokenizer_cls = tokenizer_pool_config.pool_type + elif tokenizer_pool_config.pool_type == "ray": if RayTokenizerGroupPool is None: raise ImportError( "RayTokenizerGroupPool is not available. Please install " "the ray package to use the Ray tokenizer group pool.") - return RayTokenizerGroupPool.from_config(tokenizer_pool_config, - **init_kwargs) + tokenizer_cls = RayTokenizerGroupPool else: raise ValueError( f"Unknown pool type: {tokenizer_pool_config.pool_type}") + return tokenizer_cls.from_config(tokenizer_pool_config, **init_kwargs) __all__ = ["get_tokenizer_group", "BaseTokenizerGroup"] diff --git a/vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py index 18fbd894f1c0e..9682db6966ddf 100644 --- a/vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py +++ b/vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py @@ -3,12 +3,19 @@ from transformers import PreTrainedTokenizer +from vllm.config import TokenizerPoolConfig from vllm.lora.request import LoRARequest class BaseTokenizerGroup(ABC): """A group of tokenizers that can be used for LoRA adapters.""" + @classmethod + @abstractmethod + def from_config(cls, tokenizer_pool_config: Optional[TokenizerPoolConfig], + **init_kwargs) -> "BaseTokenizerGroup": + pass + @abstractmethod def ping(self) -> bool: """Check if the tokenizer group is alive.""" diff --git a/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py index 799ca7d3f15c0..32384398a4c12 100644 --- a/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py +++ b/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py @@ -29,8 +29,10 @@ class RayTokenizerGroupPool(BaseTokenizerGroup): _worker_cls = TokenizerGroup @classmethod - def from_config(cls, tokenizer_pool_config: TokenizerPoolConfig, + def from_config(cls, tokenizer_pool_config: Optional[TokenizerPoolConfig], **init_kwargs) -> "RayTokenizerGroupPool": + if not tokenizer_pool_config: + raise ValueError("tokenizer_pool_config must not be None.") ray_actor_options = (tokenizer_pool_config.extra_config or { "num_cpus": 0 }) diff --git a/vllm/transformers_utils/tokenizer_group/tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/tokenizer_group.py index 9614f01d2b955..74c041f13bad9 100644 --- a/vllm/transformers_utils/tokenizer_group/tokenizer_group.py +++ b/vllm/transformers_utils/tokenizer_group/tokenizer_group.py @@ -2,6 +2,7 @@ from transformers import PreTrainedTokenizer +from vllm.config import TokenizerPoolConfig from vllm.lora.request import LoRARequest from vllm.transformers_utils.tokenizer import (get_lora_tokenizer, get_lora_tokenizer_async, @@ -24,6 +25,11 @@ def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int, self.lora_tokenizers = LRUCache[PreTrainedTokenizer]( capacity=max_num_seqs) if enable_lora else None + @classmethod + def from_config(cls, tokenizer_pool_config: Optional[TokenizerPoolConfig], + **init_kwargs) -> "TokenizerGroup": + return cls(**init_kwargs) + def ping(self) -> bool: """Check if the tokenizer group is alive.""" return True diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index d79b9f24e656b..8e5c0ededba15 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -2,7 +2,7 @@ import importlib import os from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union import torch @@ -315,14 +315,23 @@ class WorkerWrapperBase: We first instantiate the WorkerWrapper, which remembers the worker module and class name. Then, when we call `update_environment_variables`, and the real initialization happens in `init_worker`. + + If worker_class_fn is specified, it will be executed to get the worker + class. + Otherwise, the worker class will be obtained by dynamically importing it + using worker_module_name and worker_class_name. """ - def __init__(self, - worker_module_name: str, - worker_class_name: str, - trust_remote_code: bool = False) -> None: + def __init__( + self, + worker_module_name: str, + worker_class_name: str, + trust_remote_code: bool = False, + worker_class_fn: Optional[Callable[[], + Type[WorkerBase]]] = None) -> None: self.worker_module_name = worker_module_name self.worker_class_name = worker_class_name + self.worker_class_fn = worker_class_fn self.worker: Optional[WorkerBase] = None if trust_remote_code: # note: lazy import to avoid importing torch before initializing @@ -348,8 +357,11 @@ def init_worker(self, *args, **kwargs): # see https://github.com/NVIDIA/nccl/issues/1234 os.environ['NCCL_CUMEM_ENABLE'] = '0' - mod = importlib.import_module(self.worker_module_name) - worker_class = getattr(mod, self.worker_class_name) + if self.worker_class_fn: + worker_class = self.worker_class_fn() + else: + mod = importlib.import_module(self.worker_module_name) + worker_class = getattr(mod, self.worker_class_name) self.worker = worker_class(*args, **kwargs) assert self.worker is not None