Skip to content

Commit

Permalink
[Core] Allow specifying custom Executor (vllm-project#6557)
Browse files Browse the repository at this point in the history
Signed-off-by: Alvant <[email protected]>
  • Loading branch information
Yard1 authored and Alvant committed Oct 26, 2024
1 parent abcc9e1 commit 23c8419
Show file tree
Hide file tree
Showing 22 changed files with 310 additions and 92 deletions.
4 changes: 4 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,6 +564,10 @@ def get_tokenizer_pool_config(tokenizer_group_type):
return TokenizerPoolConfig(pool_size=1,
pool_type="ray",
extra_config={})
if isinstance(tokenizer_group_type, type):
return TokenizerPoolConfig(pool_size=1,
pool_type=tokenizer_group_type,
extra_config={})
raise ValueError(f"Unknown tokenizer_group_type: {tokenizer_group_type}")


Expand Down
91 changes: 91 additions & 0 deletions tests/engine/test_custom_executor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import asyncio
import os

import pytest

from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.llm_engine import LLMEngine
from vllm.executor.gpu_executor import GPUExecutor, GPUExecutorAsync
from vllm.sampling_params import SamplingParams


class Mock:
...


class CustomGPUExecutor(GPUExecutor):

def execute_model(self, *args, **kwargs):
# Drop marker to show that this was ran
with open(".marker", "w"):
...
return super().execute_model(*args, **kwargs)


class CustomGPUExecutorAsync(GPUExecutorAsync):

async def execute_model_async(self, *args, **kwargs):
with open(".marker", "w"):
...
return await super().execute_model_async(*args, **kwargs)


@pytest.mark.parametrize("model", ["facebook/opt-125m"])
def test_custom_executor_type_checking(model):
with pytest.raises(ValueError):
engine_args = EngineArgs(model=model,
distributed_executor_backend=Mock)
LLMEngine.from_engine_args(engine_args)
with pytest.raises(ValueError):
engine_args = AsyncEngineArgs(model=model,
distributed_executor_backend=Mock)
AsyncLLMEngine.from_engine_args(engine_args)
with pytest.raises(TypeError):
engine_args = AsyncEngineArgs(
model=model, distributed_executor_backend=CustomGPUExecutor)
AsyncLLMEngine.from_engine_args(engine_args)


@pytest.mark.parametrize("model", ["facebook/opt-125m"])
def test_custom_executor(model, tmpdir):
cwd = os.path.abspath(".")
os.chdir(tmpdir)
try:
assert not os.path.exists(".marker")

engine_args = EngineArgs(
model=model, distributed_executor_backend=CustomGPUExecutor)
engine = LLMEngine.from_engine_args(engine_args)
sampling_params = SamplingParams(max_tokens=1)

engine.add_request("0", "foo", sampling_params)
engine.step()

assert os.path.exists(".marker")
finally:
os.chdir(cwd)


@pytest.mark.parametrize("model", ["facebook/opt-125m"])
def test_custom_executor_async(model, tmpdir):
cwd = os.path.abspath(".")
os.chdir(tmpdir)
try:
assert not os.path.exists(".marker")

engine_args = AsyncEngineArgs(
model=model, distributed_executor_backend=CustomGPUExecutorAsync)
engine = AsyncLLMEngine.from_engine_args(engine_args)
sampling_params = SamplingParams(max_tokens=1)

async def t():
stream = await engine.add_request("0", "foo", sampling_params)
async for x in stream:
...

asyncio.run(t())

assert os.path.exists(".marker")
finally:
os.chdir(cwd)
21 changes: 17 additions & 4 deletions tests/tokenization/test_tokenizer_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,28 @@
import pytest
from transformers import AutoTokenizer, PreTrainedTokenizerBase

from vllm.transformers_utils.tokenizer_group import get_tokenizer_group
from vllm.transformers_utils.tokenizer_group import (TokenizerGroup,
get_tokenizer_group)
from vllm.transformers_utils.tokenizer_group.ray_tokenizer_group import (
RayTokenizerGroupPool)
from vllm.transformers_utils.tokenizer_group.tokenizer_group import (
TokenizerGroup)

from ..conftest import get_tokenizer_pool_config


class CustomTokenizerGroup(TokenizerGroup):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._i = 0

def encode(self, *args, **kwargs):
self._i += 1
return super().encode(*args, **kwargs)


@pytest.mark.asyncio
@pytest.mark.parametrize("tokenizer_group_type", [None, "ray"])
@pytest.mark.parametrize("tokenizer_group_type",
[None, "ray", CustomTokenizerGroup])
async def test_tokenizer_group(tokenizer_group_type):
reference_tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer_group = get_tokenizer_group(
Expand All @@ -36,6 +47,8 @@ async def test_tokenizer_group(tokenizer_group_type):
PreTrainedTokenizerBase)
assert tokenizer_group.get_lora_tokenizer(
None) == await tokenizer_group.get_lora_tokenizer_async(None)
if tokenizer_group_type is CustomTokenizerGroup:
assert tokenizer_group._i > 0


@pytest.mark.asyncio
Expand Down
39 changes: 28 additions & 11 deletions vllm/config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import enum
import json
from dataclasses import dataclass, field, fields
from typing import TYPE_CHECKING, ClassVar, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, ClassVar, List, Optional, Tuple, Type, Union

import torch
from transformers import PretrainedConfig
Expand All @@ -18,7 +18,10 @@
if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup

from vllm.executor.executor_base import ExecutorBase
from vllm.model_executor.model_loader.loader import BaseModelLoader
from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
BaseTokenizerGroup)

logger = init_logger(__name__)

Expand Down Expand Up @@ -527,11 +530,12 @@ class TokenizerPoolConfig:
pool type.
"""
pool_size: int
pool_type: str
pool_type: Union[str, Type["BaseTokenizerGroup"]]
extra_config: dict

def __post_init__(self):
if self.pool_type not in ("ray", ):
if self.pool_type not in ("ray", ) and not isinstance(
self.pool_type, type):
raise ValueError(f"Unknown pool type: {self.pool_type}")
if not isinstance(self.extra_config, dict):
raise ValueError("extra_config must be a dictionary.")
Expand Down Expand Up @@ -661,7 +665,8 @@ def __init__(
tokenizer_pool_config: Optional[TokenizerPoolConfig] = None,
ray_workers_use_nsight: bool = False,
placement_group: Optional["PlacementGroup"] = None,
distributed_executor_backend: Optional[str] = None,
distributed_executor_backend: Optional[Union[
str, Type["ExecutorBase"]]] = None,
) -> None:
self.pipeline_parallel_size = pipeline_parallel_size
self.tensor_parallel_size = tensor_parallel_size
Expand All @@ -676,7 +681,7 @@ def __init__(
if worker_use_ray:
if self.distributed_executor_backend is None:
self.distributed_executor_backend = "ray"
elif self.distributed_executor_backend != "ray":
elif not self.use_ray:
raise ValueError(f"worker-use-ray can't be used with "
f"distributed executor backend "
f"'{self.distributed_executor_backend}'.")
Expand Down Expand Up @@ -711,21 +716,33 @@ def __init__(
self._verify_args()
self.rank = 0

@property
def use_ray(self) -> bool:
return self.distributed_executor_backend == "ray" or (
isinstance(self.distributed_executor_backend, type)
and self.distributed_executor_backend.uses_ray)

def _verify_args(self) -> None:
if self.distributed_executor_backend not in ("ray", "mp", None):
# Lazy import to avoid circular import
from vllm.executor.executor_base import ExecutorBase

if self.distributed_executor_backend not in (
"ray", "mp", None) and not (isinstance(
self.distributed_executor_backend, type) and issubclass(
self.distributed_executor_backend, ExecutorBase)):
raise ValueError(
"Unrecognized distributed executor backend. Supported values "
"are 'ray' or 'mp'.")
if self.distributed_executor_backend == "ray":
"Unrecognized distributed executor backend "
f"{self.distributed_executor_backend}. Supported "
"values are 'ray', 'mp' or custom ExecutorBase subclass.")
if self.use_ray:
from vllm.executor import ray_utils
ray_utils.assert_ray_available()
if is_hip():
self.disable_custom_all_reduce = True
logger.info(
"Disabled the custom all-reduce kernel because it is not "
"supported on AMD GPUs.")
if self.ray_workers_use_nsight and (
not self.distributed_executor_backend == "ray"):
if self.ray_workers_use_nsight and not self.use_ray:
raise ValueError("Unable to use nsight profiling unless workers "
"run with Ray.")

Expand Down
18 changes: 15 additions & 3 deletions vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,21 @@
import dataclasses
import json
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
from typing import TYPE_CHECKING, List, Optional, Tuple, Type, Union

from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
MultiModalConfig, ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig, TokenizerPoolConfig)
from vllm.executor.executor_base import ExecutorBase
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.utils import FlexibleArgumentParser

if TYPE_CHECKING:
from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
BaseTokenizerGroup)


def nullable_str(val: str):
if not val or val == "None":
Expand All @@ -36,7 +41,11 @@ class EngineArgs:
seed: int = 0
max_model_len: Optional[int] = None
worker_use_ray: bool = False
distributed_executor_backend: Optional[str] = None
# Note: Specifying a custom executor backend by passing a class
# is intended for expert use only. The API may change without
# notice.
distributed_executor_backend: Optional[Union[str,
Type[ExecutorBase]]] = None
pipeline_parallel_size: int = 1
tensor_parallel_size: int = 1
max_parallel_loading_workers: Optional[int] = None
Expand All @@ -62,7 +71,10 @@ class EngineArgs:
max_seq_len_to_capture: int = 8192
disable_custom_all_reduce: bool = False
tokenizer_pool_size: int = 0
tokenizer_pool_type: str = "ray"
# Note: Specifying a tokenizer pool by passing a class
# is intended for expert use only. The API may change without
# notice.
tokenizer_pool_type: Union[str, Type["BaseTokenizerGroup"]] = "ray"
tokenizer_pool_extra_config: Optional[dict] = None
enable_lora: bool = False
max_loras: int = 1
Expand Down
53 changes: 34 additions & 19 deletions vllm/engine/async_llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@
from transformers import PreTrainedTokenizer

import vllm.envs as envs
from vllm.config import DecodingConfig, ModelConfig
from vllm.config import DecodingConfig, EngineConfig, ModelConfig
from vllm.core.scheduler import SchedulerOutputs
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_timeout import asyncio_timeout
from vllm.engine.llm_engine import LLMEngine
from vllm.engine.metrics import StatLoggerBase
from vllm.executor.executor_base import ExecutorAsyncBase
from vllm.executor.ray_utils import initialize_ray_cluster, ray
from vllm.inputs import LLMInputs, PromptInputs
from vllm.logger import init_logger
Expand Down Expand Up @@ -425,25 +426,19 @@ def __init__(self,
self._request_tracker: RequestTracker

@classmethod
def from_engine_args(
cls,
engine_args: AsyncEngineArgs,
start_engine_loop: bool = True,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
) -> "AsyncLLMEngine":
"""Creates an async LLM engine from the engine arguments."""
# Create the engine configs.
engine_config = engine_args.create_engine_config()

if engine_args.engine_use_ray:
from vllm.executor import ray_utils
ray_utils.assert_ray_available()

def _get_executor_cls(
cls, engine_config: EngineConfig) -> Type[ExecutorAsyncBase]:
distributed_executor_backend = (
engine_config.parallel_config.distributed_executor_backend)

if engine_config.device_config.device_type == "neuron":
if isinstance(distributed_executor_backend, type):
if not issubclass(distributed_executor_backend, ExecutorAsyncBase):
raise TypeError(
"distributed_executor_backend must be a subclass of "
f"ExecutorAsyncBase. Got {distributed_executor_backend}.")
if distributed_executor_backend.uses_ray: # type: ignore
initialize_ray_cluster(engine_config.parallel_config)
executor_class = distributed_executor_backend
elif engine_config.device_config.device_type == "neuron":
from vllm.executor.neuron_executor import NeuronExecutorAsync
executor_class = NeuronExecutorAsync
elif engine_config.device_config.device_type == "tpu":
Expand Down Expand Up @@ -482,9 +477,29 @@ def from_engine_args(
else:
from vllm.executor.gpu_executor import GPUExecutorAsync
executor_class = GPUExecutorAsync
return executor_class

@classmethod
def from_engine_args(
cls,
engine_args: AsyncEngineArgs,
start_engine_loop: bool = True,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
) -> "AsyncLLMEngine":
"""Creates an async LLM engine from the engine arguments."""
# Create the engine configs.
engine_config = engine_args.create_engine_config()

if engine_args.engine_use_ray:
from vllm.executor import ray_utils
ray_utils.assert_ray_available()

executor_class = cls._get_executor_cls(engine_config)

# Create the async LLM engine.
engine = cls(
distributed_executor_backend == "ray",
executor_class.uses_ray,
engine_args.engine_use_ray,
**engine_config.to_dict(),
executor_class=executor_class,
Expand Down
Loading

0 comments on commit 23c8419

Please sign in to comment.