From 1b62745b1d00153c5e99879edaf0c2d7ceb4e2c6 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 7 Dec 2024 09:33:45 -0800 Subject: [PATCH 01/19] [core][executor] simplify instance id (#10976) Signed-off-by: youkaichao --- vllm/config.py | 7 ++++++- vllm/envs.py | 6 ------ vllm/executor/cpu_executor.py | 6 +----- vllm/executor/multiproc_gpu_executor.py | 5 +---- vllm/executor/ray_gpu_executor.py | 7 +------ vllm/executor/ray_hpu_executor.py | 7 +------ vllm/executor/ray_tpu_executor.py | 6 +----- vllm/executor/ray_xpu_executor.py | 6 +----- vllm/utils.py | 25 +++++++++---------------- vllm/worker/worker_base.py | 2 +- 10 files changed, 22 insertions(+), 55 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index db7046ab2c22d..d1c4f995ad015 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -27,7 +27,8 @@ get_hf_text_config, get_pooling_config, get_sentence_transformer_tokenizer_config, is_encoder_decoder, uses_mrope) from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory, - print_warning_once, resolve_obj_by_qualname) + print_warning_once, random_uuid, + resolve_obj_by_qualname) if TYPE_CHECKING: from ray.util.placement_group import PlacementGroup @@ -2408,6 +2409,7 @@ class VllmConfig: init=True) # type: ignore kv_transfer_config: KVTransferConfig = field(default=None, init=True) # type: ignore + instance_id: str = "" @staticmethod def get_graph_batch_size(batch_size: int) -> int: @@ -2573,6 +2575,9 @@ def __post_init__(self): current_platform.check_and_update_config(self) + if not self.instance_id: + self.instance_id = random_uuid()[:5] + def __str__(self): return ("model=%r, speculative_config=%r, tokenizer=%r, " "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, " diff --git a/vllm/envs.py b/vllm/envs.py index 28797ac1e4af2..ab12a7b48dc53 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -8,7 +8,6 @@ VLLM_RPC_BASE_PATH: str = tempfile.gettempdir() VLLM_USE_MODELSCOPE: bool = False VLLM_RINGBUFFER_WARNING_INTERVAL: int = 60 
- VLLM_INSTANCE_ID: Optional[str] = None VLLM_NCCL_SO_PATH: Optional[str] = None LD_LIBRARY_PATH: Optional[str] = None VLLM_USE_TRITON_FLASH_ATTN: bool = False @@ -175,11 +174,6 @@ def get_default_config_root(): "VLLM_USE_MODELSCOPE": lambda: os.environ.get("VLLM_USE_MODELSCOPE", "False").lower() == "true", - # Instance id represents an instance of the VLLM. All processes in the same - # instance should have the same instance id. - "VLLM_INSTANCE_ID": - lambda: os.environ.get("VLLM_INSTANCE_ID", None), - # Interval in seconds to log a warning message when the ring buffer is full "VLLM_RINGBUFFER_WARNING_INTERVAL": lambda: int(os.environ.get("VLLM_RINGBUFFER_WARNING_INTERVAL", "60")), diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py index 6b4cb5a9a1d61..2816b5c5c1f88 100644 --- a/vllm/executor/cpu_executor.py +++ b/vllm/executor/cpu_executor.py @@ -10,8 +10,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import ExecuteModelRequest -from vllm.utils import (get_distributed_init_method, get_open_port, - get_vllm_instance_id, make_async) +from vllm.utils import get_distributed_init_method, get_open_port, make_async from vllm.worker.worker_base import WorkerWrapperBase logger = init_logger(__name__) @@ -31,9 +30,6 @@ def _init_executor(self) -> None: # Environment variables for CPU executor # - # Ensure that VLLM_INSTANCE_ID is set, to be inherited by workers - os.environ["VLLM_INSTANCE_ID"] = get_vllm_instance_id() - # Disable torch async compiling which won't work with daemonic processes os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1" diff --git a/vllm/executor/multiproc_gpu_executor.py b/vllm/executor/multiproc_gpu_executor.py index a6c05a71d2b6f..c450209f0eb91 100644 --- a/vllm/executor/multiproc_gpu_executor.py +++ b/vllm/executor/multiproc_gpu_executor.py @@ -16,7 +16,7 @@ from vllm.triton_utils.importing import HAS_TRITON from vllm.utils 
import (_run_task_with_lock, cuda_device_count_stateless, cuda_is_initialized, get_distributed_init_method, - get_open_port, get_vllm_instance_id, make_async, + get_open_port, make_async, update_environment_variables) if HAS_TRITON: @@ -37,9 +37,6 @@ def _init_executor(self) -> None: world_size = self.parallel_config.world_size tensor_parallel_size = self.parallel_config.tensor_parallel_size - # Ensure that VLLM_INSTANCE_ID is set, to be inherited by workers - os.environ["VLLM_INSTANCE_ID"] = get_vllm_instance_id() - # Disable torch async compiling which won't work with daemonic processes os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1" diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index 6542b18ae70b1..6554cda6b637b 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -15,8 +15,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput from vllm.sequence import ExecuteModelRequest from vllm.utils import (_run_task_with_lock, get_distributed_init_method, - get_ip, get_open_port, get_vllm_instance_id, - make_async) + get_ip, get_open_port, make_async) if ray is not None: from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy @@ -220,14 +219,10 @@ def sort_by_driver_then_worker_ip(worker): " environment variable, make sure it is unique for" " each node.") - VLLM_INSTANCE_ID = get_vllm_instance_id() - # Set environment variables for the driver and workers. 
all_args_to_update_environment_variables = [({ "CUDA_VISIBLE_DEVICES": ",".join(map(str, node_gpus[node_id])), - "VLLM_INSTANCE_ID": - VLLM_INSTANCE_ID, "VLLM_TRACE_FUNCTION": str(envs.VLLM_TRACE_FUNCTION), **({ diff --git a/vllm/executor/ray_hpu_executor.py b/vllm/executor/ray_hpu_executor.py index a74328e5aa272..91c84d9214a60 100644 --- a/vllm/executor/ray_hpu_executor.py +++ b/vllm/executor/ray_hpu_executor.py @@ -15,8 +15,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput from vllm.sequence import ExecuteModelRequest from vllm.utils import (_run_task_with_lock, get_distributed_init_method, - get_ip, get_open_port, get_vllm_instance_id, - make_async) + get_ip, get_open_port, make_async) if ray is not None: from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy @@ -196,12 +195,8 @@ def sort_by_driver_then_worker_ip(worker): "environment variable, make sure it is unique for" " each node.") - VLLM_INSTANCE_ID = get_vllm_instance_id() - # Set environment variables for the driver and workers. 
all_args_to_update_environment_variables = [({ - "VLLM_INSTANCE_ID": - VLLM_INSTANCE_ID, "VLLM_TRACE_FUNCTION": str(envs.VLLM_TRACE_FUNCTION), }, ) for (node_id, _) in worker_node_and_gpu_ids] diff --git a/vllm/executor/ray_tpu_executor.py b/vllm/executor/ray_tpu_executor.py index c227b5e283c68..3ee59397bf4c9 100644 --- a/vllm/executor/ray_tpu_executor.py +++ b/vllm/executor/ray_tpu_executor.py @@ -13,7 +13,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput from vllm.sequence import ExecuteModelRequest from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, - get_vllm_instance_id, make_async) + make_async) if ray is not None: from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy @@ -144,12 +144,8 @@ def sort_by_driver_then_worker_ip(worker): for i, (node_id, _) in enumerate(worker_node_and_gpu_ids): node_workers[node_id].append(i) - VLLM_INSTANCE_ID = get_vllm_instance_id() - # Set environment variables for the driver and workers. all_args_to_update_environment_variables = [({ - "VLLM_INSTANCE_ID": - VLLM_INSTANCE_ID, "VLLM_TRACE_FUNCTION": str(envs.VLLM_TRACE_FUNCTION), }, ) for _ in worker_node_and_gpu_ids] diff --git a/vllm/executor/ray_xpu_executor.py b/vllm/executor/ray_xpu_executor.py index 2b1cdc09b0a9f..61f5d6a65e999 100644 --- a/vllm/executor/ray_xpu_executor.py +++ b/vllm/executor/ray_xpu_executor.py @@ -5,7 +5,7 @@ from vllm.executor.ray_gpu_executor import RayGPUExecutor, RayGPUExecutorAsync from vllm.executor.xpu_executor import XPUExecutor from vllm.logger import init_logger -from vllm.utils import get_vllm_instance_id, make_async +from vllm.utils import make_async logger = init_logger(__name__) @@ -17,12 +17,8 @@ def _get_env_vars_to_be_updated(self): worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids", use_dummy_driver=True) - VLLM_INSTANCE_ID = get_vllm_instance_id() - # Set environment variables for the driver and workers. 
all_args_to_update_environment_variables = [({ - "VLLM_INSTANCE_ID": - VLLM_INSTANCE_ID, "VLLM_TRACE_FUNCTION": str(envs.VLLM_TRACE_FUNCTION), }, ) for (_, _) in worker_node_and_gpu_ids] diff --git a/vllm/utils.py b/vllm/utils.py index 6cee4847e57b4..1f19d9eacd16d 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -24,9 +24,9 @@ from collections.abc import Iterable, Mapping from functools import lru_cache, partial, wraps from platform import uname -from typing import (Any, AsyncGenerator, Awaitable, Callable, Dict, Generic, - Hashable, List, Literal, Optional, OrderedDict, Set, Tuple, - Type, TypeVar, Union, overload) +from typing import (TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable, + Dict, Generic, Hashable, List, Literal, Optional, + OrderedDict, Set, Tuple, Type, TypeVar, Union, overload) from uuid import uuid4 import numpy as np @@ -43,6 +43,9 @@ from vllm.logger import enable_trace_function_call, init_logger from vllm.platforms import current_platform +if TYPE_CHECKING: + from vllm.config import VllmConfig + logger = init_logger(__name__) # Exception strings for non-implemented encoder/decoder scenarios @@ -335,17 +338,6 @@ def random_uuid() -> str: return str(uuid.uuid4().hex) -@lru_cache(maxsize=None) -def get_vllm_instance_id() -> str: - """ - If the environment variable VLLM_INSTANCE_ID is set, return it. - Otherwise, return a random UUID. - Instance id represents an instance of the VLLM. All processes in the same - instance should have the same instance id. 
- """ - return envs.VLLM_INSTANCE_ID or f"vllm-instance-{random_uuid()}" - - @lru_cache(maxsize=None) def in_wsl() -> bool: # Reference: https://github.com/microsoft/WSL/issues/4071 @@ -997,7 +989,7 @@ def find_nccl_library() -> str: return so_file -def enable_trace_function_call_for_thread() -> None: +def enable_trace_function_call_for_thread(vllm_config: "VllmConfig") -> None: """Set up function tracing for the current thread, if enabled via the VLLM_TRACE_FUNCTION environment variable """ @@ -1009,7 +1001,8 @@ def enable_trace_function_call_for_thread() -> None: filename = (f"VLLM_TRACE_FUNCTION_for_process_{os.getpid()}" f"_thread_{threading.get_ident()}_" f"at_{datetime.datetime.now()}.log").replace(" ", "_") - log_path = os.path.join(tmp_dir, "vllm", get_vllm_instance_id(), + log_path = os.path.join(tmp_dir, "vllm", + f"vllm-instance-{vllm_config.instance_id}", filename) os.makedirs(os.path.dirname(log_path), exist_ok=True) enable_trace_function_call(log_path) diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 7c0bc5a678956..6d00102e0a324 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -439,7 +439,7 @@ def init_worker(self, *args, **kwargs): Here we inject some common logic before initializing the worker. Arguments are passed to the worker class constructor. 
""" - enable_trace_function_call_for_thread() + enable_trace_function_call_for_thread(self.vllm_config) # see https://github.com/NVIDIA/nccl/issues/1234 os.environ['NCCL_CUMEM_ENABLE'] = '0' From 7be15d9356a10c6ae3537565548e4f8bf46f35dd Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 7 Dec 2024 12:06:08 -0800 Subject: [PATCH 02/19] [core][misc] remove use_dummy driver for _run_workers (#10920) Signed-off-by: youkaichao --- vllm/executor/ray_gpu_executor.py | 27 ++++++++++++--------------- vllm/executor/ray_hpu_executor.py | 28 ++++++++++++---------------- vllm/executor/ray_tpu_executor.py | 21 ++++++++++----------- vllm/executor/ray_xpu_executor.py | 11 +++++++++-- 4 files changed, 43 insertions(+), 44 deletions(-) diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index 6554cda6b637b..4263fb27265f6 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -188,8 +188,14 @@ def sort_by_driver_then_worker_ip(worker): self.workers = sorted(self.workers, key=sort_by_driver_then_worker_ip) # Get the set of GPU IDs used on each node. - worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids", - use_dummy_driver=True) + worker_node_and_gpu_ids = [] + for worker in [self.driver_dummy_worker] + self.workers: + if worker is None: + # driver_dummy_worker can be None when using ray spmd worker. 
+ continue + worker_node_and_gpu_ids.append( + ray.get(worker.get_node_and_gpu_ids.remote()) \ + ) # type: ignore node_workers = defaultdict(list) # node id -> list of worker ranks node_gpus = defaultdict(list) # node id -> list of gpu ids @@ -329,7 +335,6 @@ def _run_workers( async_run_tensor_parallel_workers_only: bool = False, all_args: Optional[List[Tuple[Any, ...]]] = None, all_kwargs: Optional[List[Dict[str, Any]]] = None, - use_dummy_driver: bool = False, max_concurrent_workers: Optional[int] = None, **kwargs, ) -> Any: @@ -389,18 +394,10 @@ def _run_workers( driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0] # Start the driver worker after all the ray workers. - if not use_dummy_driver: - driver_worker_output = [ - self.driver_worker.execute_method(method, *driver_args, - **driver_kwargs) - ] - else: - assert self.driver_dummy_worker is not None - driver_worker_output = [ - ray.get( - self.driver_dummy_worker.execute_method.remote( - method, *driver_args, **driver_kwargs)) - ] + driver_worker_output = [ + self.driver_worker.execute_method(method, *driver_args, + **driver_kwargs) + ] # Get the results of the ray workers. if self.workers: diff --git a/vllm/executor/ray_hpu_executor.py b/vllm/executor/ray_hpu_executor.py index 91c84d9214a60..f3025cb537ab8 100644 --- a/vllm/executor/ray_hpu_executor.py +++ b/vllm/executor/ray_hpu_executor.py @@ -163,9 +163,14 @@ def sort_by_driver_then_worker_ip(worker): # node will be placed first. self.workers = sorted(self.workers, key=sort_by_driver_then_worker_ip) - # Get the set of GPU IDs used on each node. - worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids", - use_dummy_driver=True) + worker_node_and_gpu_ids = [] + for worker in [self.driver_dummy_worker] + self.workers: + if worker is None: + # driver_dummy_worker can be None when using ray spmd worker. 
+ continue + worker_node_and_gpu_ids.append( + ray.get(worker.get_node_and_gpu_ids.remote()) \ + ) # type: ignore node_workers = defaultdict(list) # node id -> list of worker ranks node_gpus = defaultdict(list) # node id -> list of gpu ids @@ -296,7 +301,6 @@ def _run_workers( async_run_tensor_parallel_workers_only: bool = False, all_args: Optional[List[Tuple[Any, ...]]] = None, all_kwargs: Optional[List[Dict[str, Any]]] = None, - use_dummy_driver: bool = False, max_concurrent_workers: Optional[int] = None, **kwargs, ) -> Any: @@ -356,18 +360,10 @@ def _run_workers( driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0] # Start the driver worker after all the ray workers. - if not use_dummy_driver: - driver_worker_output = [ - self.driver_worker.execute_method(method, *driver_args, - **driver_kwargs) - ] - else: - assert self.driver_dummy_worker is not None - driver_worker_output = [ - ray.get( - self.driver_dummy_worker.execute_method.remote( - method, *driver_args, **driver_kwargs)) - ] + driver_worker_output = [ + self.driver_worker.execute_method(method, *driver_args, + **driver_kwargs) + ] # Get the results of the ray workers. if self.workers: diff --git a/vllm/executor/ray_tpu_executor.py b/vllm/executor/ray_tpu_executor.py index 3ee59397bf4c9..5118c13934f0d 100644 --- a/vllm/executor/ray_tpu_executor.py +++ b/vllm/executor/ray_tpu_executor.py @@ -137,8 +137,14 @@ def sort_by_driver_then_worker_ip(worker): self.workers = sorted(self.workers, key=sort_by_driver_then_worker_ip) # Get the set of TPU IDs used on each node. - worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids", - use_dummy_driver=True) + worker_node_and_gpu_ids = [] + for worker in [self.driver_dummy_worker] + self.workers: + if worker is None: + # driver_dummy_worker can be None when using ray spmd worker. 
+ continue + worker_node_and_gpu_ids.append( + ray.get(worker.get_node_and_gpu_ids.remote()) \ + ) # type: ignore node_workers = defaultdict(list) for i, (node_id, _) in enumerate(worker_node_and_gpu_ids): @@ -199,7 +205,6 @@ def _run_workers( async_run_remote_workers_only: bool = False, all_args: Optional[List[Tuple[Any, ...]]] = None, all_kwargs: Optional[List[Dict[str, Any]]] = None, - use_dummy_driver: bool = False, max_concurrent_workers: Optional[int] = None, use_ray_compiled_dag: bool = False, **kwargs, @@ -241,14 +246,8 @@ def _run_workers( driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0] # Start the driver worker after all the ray workers. - if not use_dummy_driver: - driver_worker_output = self.driver_worker.execute_method( - method, *driver_args, **driver_kwargs) - else: - assert self.driver_dummy_worker is not None - driver_worker_output = ray.get( - self.driver_dummy_worker.execute_method.remote( - method, *driver_args, **driver_kwargs)) + driver_worker_output = self.driver_worker.execute_method( + method, *driver_args, **driver_kwargs) # Get the results of the ray workers. if self.workers: ray_worker_outputs = ray.get(ray_worker_outputs) diff --git a/vllm/executor/ray_xpu_executor.py b/vllm/executor/ray_xpu_executor.py index 61f5d6a65e999..d2086f5fef26c 100644 --- a/vllm/executor/ray_xpu_executor.py +++ b/vllm/executor/ray_xpu_executor.py @@ -1,6 +1,8 @@ import asyncio from typing import List, Optional +import ray + import vllm.envs as envs from vllm.executor.ray_gpu_executor import RayGPUExecutor, RayGPUExecutorAsync from vllm.executor.xpu_executor import XPUExecutor @@ -14,8 +16,13 @@ class RayXPUExecutor(RayGPUExecutor, XPUExecutor): def _get_env_vars_to_be_updated(self): # Get the set of GPU IDs used on each node. 
- worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids", - use_dummy_driver=True) + worker_node_and_gpu_ids = [] + for worker in [self.driver_dummy_worker] + self.workers: + if worker is None: + # driver_dummy_worker can be None when using ray spmd worker. + continue + worker_node_and_gpu_ids.append( + ray.get(worker.get_node_and_gpu_ids.remote())) # type: ignore # Set environment variables for the driver and workers. all_args_to_update_environment_variables = [({ From fd57d2b5347e8fe6da9287553d4b5a3aaf2e6693 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 8 Dec 2024 03:05:21 -0800 Subject: [PATCH 03/19] [torch.compile] allow candidate compile sizes (#10984) Signed-off-by: youkaichao --- tests/engine/test_arg_utils.py | 8 +++---- vllm/config.py | 44 +++++++++++++++++----------------- vllm/engine/arg_utils.py | 5 +--- vllm/entrypoints/llm.py | 6 +---- 4 files changed, 28 insertions(+), 35 deletions(-) diff --git a/tests/engine/test_arg_utils.py b/tests/engine/test_arg_utils.py index de78d41ad12eb..4e269de9fc40b 100644 --- a/tests/engine/test_arg_utils.py +++ b/tests/engine/test_arg_utils.py @@ -50,12 +50,12 @@ def test_compilation_config(): args = parser.parse_args(["-O=3"]) assert args.compilation_config.level == 3 - # set to json - args = parser.parse_args(["--compilation-config", '{"level": 3}']) + # set to string form of a dict + args = parser.parse_args(["--compilation-config", "{'level': 3}"]) assert args.compilation_config.level == 3 - # set to json - args = parser.parse_args(['--compilation-config={"level": 3}']) + # set to string form of a dict + args = parser.parse_args(["--compilation-config={'level': 3}"]) assert args.compilation_config.level == 3 diff --git a/vllm/config.py b/vllm/config.py index d1c4f995ad015..164622b5af34e 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1,3 +1,4 @@ +import ast import copy import enum import hashlib @@ -2191,14 +2192,10 @@ class CompilationConfig(BaseModel): - use_inductor: whether to use 
inductor compilation. - False: inductor compilation is not used. graph runs in eager. - True: inductor compilation is used. one graph for symbolic shape - is compiled. In addition, compile for different sizes specified - in inductor_compile_sizes, using configurations + is compiled. In addition, compile for cudagraph sizes that are + in candidate_compile_sizes, using configurations in inductor_compile_config. - - inductor_compile_sizes: sizes to compile for inductor. - - inductor_specialize_for_cudagraph_no_more_than: an optional integer - to specialize inductor for cudagraph sizes no more than the - specified size. It is useful when we want to specialize inductor - with a subset of cudagraph sizes. + - candidate_compile_sizes: sizes to compile for inductor. - inductor_compile_config: additional configurations for inductor. - None: use default configurations. - inductor_passes: additional passes for inductor. It is a dictionary @@ -2227,8 +2224,7 @@ class CompilationConfig(BaseModel): ]) use_inductor: bool = True - inductor_specialize_for_cudagraph_no_more_than: Optional[int] = None - inductor_compile_sizes: Optional[List[int]] = Field(default=None) + candidate_compile_sizes: Optional[List[int]] = Field(default=None) inductor_compile_config: Dict = Field(default_factory=dict) inductor_passes: Dict[str, str] = Field(default_factory=dict) @@ -2294,7 +2290,9 @@ def from_cli(cls, cli_value: str) -> "CompilationConfig": """Parse the CLI value for the compilation config.""" if cli_value in ["0", "1", "2", "3"]: return cls(level=int(cli_value)) - return CompilationConfig.model_validate_json(cli_value) + # do not use `eval`, it is dangerous and can execute arbitrary code + dict_value = ast.literal_eval(cli_value) + return CompilationConfig.model_validate(dict_value) def model_post_init(self, __context: Any) -> None: @@ -2355,18 +2353,20 @@ def init_with_cudagraph_sizes(self, sizes_to_specialize: List[int]): logger.info(("cudagraph sizes specified by model runner" " %s is 
overridden by config %s"), sizes_to_specialize, self.cudagraph_capture_sizes) - if self.inductor_specialize_for_cudagraph_no_more_than is not None: - assert self.inductor_compile_sizes is None, ( - "inductor_compile_sizes should be None when " - "inductor_specialize_for_cudagraph_no_more_than is not None") - self.compile_sizes = [ - x for x in self.capture_sizes - if x <= self.inductor_specialize_for_cudagraph_no_more_than - ] - else: - if self.inductor_compile_sizes is None: - self.inductor_compile_sizes = [] - self.compile_sizes = self.inductor_compile_sizes + + if self.candidate_compile_sizes is None: + self.candidate_compile_sizes = [] + self.compile_sizes = [ + x for x in self.candidate_compile_sizes if x in self.capture_sizes + ] + ignored_sizes = [ + x for x in self.candidate_compile_sizes + if x not in self.capture_sizes + ] + if ignored_sizes: + logger.warning(("candidate_compile_sizes %s are ignored " + "because they are not cudagraph capture sizes."), + ignored_sizes) # sort to make sure cudagraph capture sizes are in descending order self.capture_sizes.sort(reverse=True) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index ccd9fac225cba..96c11ec2b4f9e 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -209,12 +209,9 @@ def __post_init__(self): # support `EngineArgs(compilation_config={...})` # without having to manually construct a # CompilationConfig object - if isinstance(self.compilation_config, (int)): + if isinstance(self.compilation_config, (int, dict)): self.compilation_config = CompilationConfig.from_cli( str(self.compilation_config)) - elif isinstance(self.compilation_config, (dict)): - self.compilation_config = CompilationConfig.from_cli( - json.dumps(self.compilation_config)) # Setup plugins from vllm.plugins import load_general_plugins diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 65fa9873df28c..8de30ccd18a11 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ 
-1,5 +1,4 @@ import itertools -import json import warnings from contextlib import contextmanager from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple, Type, @@ -186,12 +185,9 @@ def __init__( kwargs["disable_log_stats"] = True if compilation_config is not None: - if isinstance(compilation_config, (int)): + if isinstance(compilation_config, (int, dict)): compilation_config_instance = CompilationConfig.from_cli( str(compilation_config)) - elif isinstance(compilation_config, (dict)): - compilation_config_instance = CompilationConfig.from_cli( - json.dumps(compilation_config)) else: compilation_config_instance = compilation_config else: From a11f3265282c712d1d9fa75368e2a8c40019fbb7 Mon Sep 17 00:00:00 2001 From: Roger Wang <136131678+ywang96@users.noreply.github.com> Date: Sun, 8 Dec 2024 04:50:51 -0800 Subject: [PATCH 04/19] [V1] Initial support of multimodal models for V1 re-arch (#10699) Signed-off-by: Roger Wang --- vllm/engine/arg_utils.py | 16 +-- vllm/model_executor/models/interfaces.py | 5 + vllm/model_executor/models/internvl.py | 68 ++++++++++--- vllm/model_executor/models/molmo.py | 72 ++++++++++++-- vllm/model_executor/models/pixtral.py | 121 +++++++++++++++++------ vllm/model_executor/models/utils.py | 28 +++++- vllm/multimodal/inputs.py | 3 +- vllm/multimodal/utils.py | 10 +- vllm/v1/core/scheduler.py | 4 +- vllm/v1/engine/llm_engine.py | 24 ++++- vllm/v1/engine/mm_input_mapper.py | 2 +- 11 files changed, 284 insertions(+), 69 deletions(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 96c11ec2b4f9e..3db069ec64ee4 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1050,9 +1050,12 @@ def create_engine_config(self, # long context (> 32K) models. This is to avoid OOM errors in the # initial memory profiling phase. - # Chunked prefill is currently disabled for multimodal models by - # default. 
- if use_long_context and not model_config.is_multimodal_model: + # For multimodal models, chunked prefill is disabled by default in + # V0, but enabled by design in V1 + if model_config.is_multimodal_model: + self.enable_chunked_prefill = bool(envs.VLLM_USE_V1) + + elif use_long_context: is_gpu = device_config.device_type == "cuda" use_sliding_window = (model_config.get_sliding_window() is not None) @@ -1241,12 +1244,9 @@ def _override_v1_engine_config(self, engine_config: VllmConfig) -> None: Override the EngineConfig's configs based on the usage context for V1. """ assert envs.VLLM_USE_V1, "V1 is not enabled" - # TODO (ywang96): Enable APC by default when VLM supports it. if engine_config.model_config.is_multimodal_model: - logger.warning( - "Prefix caching is currently not supported for multimodal " - "models and has been disabled.") - engine_config.cache_config.enable_prefix_caching = False + # TODO (ywang96): Enable APC by default when VLM supports it. + assert not engine_config.cache_config.enable_prefix_caching @dataclass diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index 01a381381ccec..c3979eab905db 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -36,6 +36,11 @@ def get_multimodal_embeddings(self, **kwargs) -> Optional[T]: """ Returns multimodal embeddings generated from multimodal kwargs to be merged with text embeddings. + + The output embeddings must be one of the following formats: + - A list or tuple of 2D tensors, where each tensor corresponds to + each input image. + - A single 3D tensor, with the batch dimension grouping the 2D tensors. """ ... 
diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index d5a7781fecfc3..42c769f79e202 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -26,7 +26,7 @@ InternVisionPatchModel) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs -from vllm.multimodal.inputs import NestedTensors +from vllm.multimodal.inputs import NestedTensors, PlaceholderRange from vllm.multimodal.utils import cached_get_tokenizer from vllm.sequence import IntermediateTensors from vllm.utils import is_list_of @@ -52,12 +52,18 @@ class InternVLImagePixelInputs(TypedDict): Shape: `(batch_size * num_images * (1 + num_patches), num_channels, height, width)` """ + patches_per_image: List[int] + """ + List of number of total patches for each image in the batch. + """ class InternVLImageEmbeddingInputs(TypedDict): type: Literal["image_embeds"] - data: torch.Tensor - """Shape: `(batch_size * num_images, image_feature_size, hidden_size)` + data: NestedTensors + """ + A tensor of shape `(num_images, total_image_feature_size, hidden_size)` + or a list of tensors of shape `(total_image_feature_size, hidden_size)` `hidden_size` must match the hidden size of language model backbone. 
""" @@ -349,10 +355,32 @@ def input_processor( new_prompt = self._expand_image_prompt(prompt, image_feature_sizes, num_patches) new_prompt_token_ids = tokenizer.encode(new_prompt) + img_context_token_id = tokenizer.encode(self.img_context_token, + add_special_tokens=False) + assert len(img_context_token_id) == 1, \ + (f"Invalid image token '{self.img_context_token}': A valid image " + f"token encodes to a single token ID, got {img_context_token_id}.") + img_context_token_id = img_context_token_id[0] + + # Get precise tracking of placeholder positions + token_idx = image_idx = 0 + placeholder_ranges = [] + while token_idx < len(new_prompt_token_ids): + if new_prompt_token_ids[token_idx] == img_context_token_id: + curr_image_featue_size = image_feature_sizes[image_idx] + placeholder_ranges.append( + PlaceholderRange(offset=token_idx, + length=curr_image_featue_size)) + image_idx += 1 + token_idx += curr_image_featue_size + else: + token_idx += 1 - return token_inputs(prompt=prompt, - prompt_token_ids=new_prompt_token_ids, - multi_modal_data=multi_modal_data) + return token_inputs( + prompt=prompt, + prompt_token_ids=new_prompt_token_ids, + multi_modal_data=multi_modal_data, + multi_modal_placeholders={"image": placeholder_ranges}) def input_mapper( self, @@ -614,26 +642,46 @@ def _parse_and_validate_image_input( if not isinstance(pixel_values, (torch.Tensor, list)): raise ValueError("Incorrect type of pixel values. " f"Got type: {type(pixel_values)}") + + patches_per_image = [] + for request_pixel_values in pixel_values: + for image_pixel_values in request_pixel_values: + patches_per_image.append(image_pixel_values.shape[0]) # We need to flatten (B, N, P) to (B*N*P), # so we call flatten_bn twice. 
return InternVLImagePixelInputs( type="pixel_values", data=self._validate_pixel_values( flatten_bn(flatten_bn(pixel_values), concat=True)), - ) + patches_per_image=patches_per_image) raise AssertionError("This line should be unreachable.") def _process_image_input( self, image_input: InternVLImageInputs, - ) -> torch.Tensor: + ) -> Tuple[torch.Tensor]: if image_input["type"] == "image_embeds": return image_input["data"] assert self.vision_model is not None + image_embeds = self.extract_feature(image_input["data"]) + patches_per_image = image_input["patches_per_image"] + if len(patches_per_image) == 1: + image_embeds = image_embeds.unsqueeze(0) + return image_embeds + + # NOTE: Image embeddings are split into separate tensors for each image + # by the size of each embedding. + feature_size = image_embeds.shape[1] + image_embeds = image_embeds.view(-1, + self.config.text_config.hidden_size) + image_feature_sizes = [ + num_patches * feature_size for num_patches in patches_per_image + ] + image_embeds = image_embeds.split(image_feature_sizes) return image_embeds def _set_visual_token_mask(self, input_ids: torch.Tensor) -> torch.Tensor: @@ -696,13 +744,11 @@ def forward( "inputs_embeds": inputs_embeds, } + # Only required if the model is mono-architecture if self.visual_token_mask is not None: - # overwrite visual_token_mask and img_context_token_id back to None, - # so that this doesn't need to depend on encoder output forward_kwargs.update( {"visual_token_mask": self.visual_token_mask}) self.visual_token_mask = None - self.img_context_token_id = None hidden_states = self.language_model.model(**forward_kwargs) return hidden_states diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index d1fcbd167c199..a328b5a2aeea7 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -37,7 +37,7 @@ ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import 
default_weight_loader from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs -from vllm.multimodal.inputs import NestedTensors +from vllm.multimodal.inputs import NestedTensors, PlaceholderRange from vllm.multimodal.utils import cached_get_tokenizer from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, SequenceData) @@ -46,12 +46,16 @@ from .interfaces import SupportsMultiModal, SupportsPP from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) + maybe_prefix, merge_multimodal_embeddings) # TODO: hard-coded for now. Consider making it configurable. VIT_LAYERS = [-2, -9] NUM_PREFIX_TOKENS = 1 ADDITIONAL_VOCAB_SIZE = 128 +DEFAULT_IMAGE_PATCH_TOKEN_ID = 152066 +DEFAULT_IM_START_TOKEN_ID = 152067 +DEFAULT_IM_END_TOKEN_ID = 152064 +DEFAULT_IM_COL_TOKEN_ID = 152065 class MolmoImageInputs(TypedDict): @@ -75,6 +79,11 @@ class MolmoImageInputs(TypedDict): `(batch_size, num_crops, num_patch)` """ + image_start_end: Tuple[int, int] + """Starting and ending index of placeholder + tokens + """ + @dataclass class VisionBackboneConfig: @@ -918,6 +927,8 @@ def image_input_mapper_for_molmo( ctx: InputContext, data: object, ): + if isinstance(data, list): + data = data[0] return MultiModalKwargs(data) @@ -967,7 +978,22 @@ def dummy_data_for_molmo(ctx: InputContext, seq_len: int, if "image_masks" in out: dummy_imgdata["image_masks"] = out["image_masks"] dummy_imgdata["seq_len"] = torch.tensor(seq_len, dtype=torch.long) - return DummyData(dummy_seqdata, {"image": dummy_imgdata}) + size = 0 + offset = -1 + for i in range(len(token_ids)): + if token_ids[i] in (DEFAULT_IMAGE_PATCH_TOKEN_ID, + DEFAULT_IM_START_TOKEN_ID, DEFAULT_IM_END_TOKEN_ID, + DEFAULT_IM_COL_TOKEN_ID): + if offset < 0: + offset = i + size += 1 + dummy_imgdata["image_start_end"] = (offset, offset + size) + return DummyData(seq_data=dummy_seqdata, + multi_modal_data={"image": dummy_imgdata}, + 
multi_modal_placeholders={ + "image": + [PlaceholderRange(offset=offset, length=size)] + }) def pad_images( @@ -1055,19 +1081,34 @@ def input_processor_for_molmo(ctx: InputContext, inputs: DecoderOnlyInputs): if image_masks is not None: image_data["image_masks"] = image_masks - image_data["seq_len"] = torch.tensor(len(out["input_ids"]), + new_prompt_token_ids = out["input_ids"].tolist() + image_data["seq_len"] = torch.tensor(len(new_prompt_token_ids), dtype=torch.long) multi_modal_data = dict(image=image_data) + size = 0 + offset = -1 + for i in range(len(new_prompt_token_ids)): + if new_prompt_token_ids[i] in (DEFAULT_IMAGE_PATCH_TOKEN_ID, + DEFAULT_IM_START_TOKEN_ID, + DEFAULT_IM_END_TOKEN_ID, + DEFAULT_IM_COL_TOKEN_ID): + if offset < 0: + offset = i + size += 1 + image_data["image_start_end"] = (offset, offset + size) prompt = inputs.get("prompt") if prompt is None: - prompt = tokenizer.decode(out["input_ids"]) + prompt = tokenizer.decode(new_prompt_token_ids) return token_inputs( - prompt_token_ids=out["input_ids"], + prompt_token_ids=new_prompt_token_ids, prompt=prompt, multi_modal_data=multi_modal_data, + multi_modal_placeholders={ + "image": [PlaceholderRange(offset=offset, length=size)] + }, ) @@ -1113,6 +1154,7 @@ def _parse_and_validate_image_input( ) -> Optional[MolmoImageInputs]: images = kwargs.pop("images", None) image_masks = kwargs.pop("image_masks", None) + image_start_end = kwargs.pop("image_start_end", None) if images is None: return None @@ -1130,6 +1172,7 @@ def _parse_and_validate_image_input( image_input_idx=image_input_idx, seq_len=seq_len, image_masks=image_masks, + image_start_end=image_start_end, ) def _process_image_input( @@ -1178,9 +1221,16 @@ def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: # Note: In this original implementation from AI2, the final # vision_embeddings will be always be the same length - # of input embedddings, which is not very efficient. - # TODO(ywang96): see if this can be optimized. 
+ # of input embeddings. vision_embeddings = torch.einsum('nd,nm->md', image_features, mat) + + # Split by the sizes of the input sequences. For each full embedding, + # extract the actual vision embeddings to be merged. + vision_embeddings = list(vision_embeddings.split(seq_len.tolist())) + for i in range(len(vision_embeddings)): + start, end = image_input['image_start_end'][i] + vision_embeddings[i] = vision_embeddings[i][start:end] + return vision_embeddings def get_input_embeddings( @@ -1190,7 +1240,11 @@ def get_input_embeddings( ) -> torch.Tensor: inputs_embeds = self.model.get_input_embeddings(input_ids) if multimodal_embeddings is not None: - inputs_embeds = inputs_embeds + multimodal_embeddings + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, multimodal_embeddings, [ + DEFAULT_IMAGE_PATCH_TOKEN_ID, DEFAULT_IM_START_TOKEN_ID, + DEFAULT_IM_END_TOKEN_ID, DEFAULT_IM_COL_TOKEN_ID + ]) return inputs_embeds def forward( diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index 215727cadd954..c6786c363ab4a 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -48,6 +48,9 @@ except ImportError: USE_XFORMERS_OPS = False +PIXTRAL_IMAGE_BREAK_ID = 12 +PIXTRAL_IMAGE_END_ID = 13 + def get_max_pixtral_image_tokens(ctx: InputContext): tokenizer = cached_get_tokenizer( @@ -68,7 +71,6 @@ def dummy_data_for_pixtral(ctx: InputContext, seq_len: int, tokenizer_mode=ctx.model_config.tokenizer_mode) mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder - patch_size = mm_encoder.mm_config.image_patch_size image_token_id = mm_encoder.special_ids.img mm_config = ctx.model_config.multimodal_config @@ -78,8 +80,8 @@ def dummy_data_for_pixtral(ctx: InputContext, seq_len: int, size = 256 image = Image.new("RGB", (size, size), color=0) - image_feature_size = (size**2) // (patch_size**2) - + encoding = tokenizer.instruct.mm_encoder(ImageChunk(image=image)) + 
image_feature_size = len(encoding.tokens) num_image_tokens = image_feature_size * num_images seq_data = SequenceData.from_prompt_token_counts( (image_token_id, num_image_tokens), @@ -101,14 +103,13 @@ def input_mapper_for_pixtral(ctx: InputContext, Args: ctx: Context of the loaded model. - data: data potentially containing image/image embeddings to be mapped - to pixel_values in .forward() for a visual QWenLMHeadModel model. + data: data potentially containing PIL images to be processed + and mapped to `images`. Returns: MultiModalKwargs containing the stacked normalized images tensor or image embeddings. """ - # Early exit if we have provided an image to a language only Qwen model model_config = ctx.model_config tokenizer = cached_get_tokenizer( model_config.tokenizer, tokenizer_mode=model_config.tokenizer_mode) @@ -116,35 +117,67 @@ def input_mapper_for_pixtral(ctx: InputContext, data_list = data if isinstance(data, list) else [data] images = [] + image_tokens_list = [] for image_data in data_list: image = ImageChunk(image=image_data) encoding = tokenizer.instruct.mm_encoder(image) image = torch.from_numpy(encoding.image).to(device="cuda", dtype=torch.float16) images.append(image) + image_tokens_list.append(encoding.tokens) - return MultiModalKwargs({"images": images}) + image_tokens = torch.tensor([ + token_id for image_tokens in image_tokens_list + for token_id in image_tokens + ]) + return MultiModalKwargs({"images": images, "image_tokens": image_tokens}) def input_processor_for_pixtral(ctx: InputContext, inputs: DecoderOnlyInputs): multi_modal_data = inputs.get("multi_modal_data") - if multi_modal_data is not None and "image" in multi_modal_data: - tokenizer = cached_get_tokenizer( - ctx.model_config.tokenizer, - tokenizer_mode=ctx.model_config.tokenizer_mode) - - mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder - image_token_id = mm_encoder.special_ids.img + if multi_modal_data is None or "image" not in multi_modal_data: + return inputs - if 
image_token_id not in inputs['prompt_token_ids']: - raise ValueError( - f"You've passed {inputs=} without {image_token_id=}" - " Make sure to process your input via mistral_common's" - " tokenizer or pass a chat completion request. For more" - " For more info, see: " - "https://github.com/vllm-project/vllm/issues/8411.") + prompt_token_ids = inputs.get("prompt_token_ids") + prompt = inputs.get("prompt") + tokenizer = cached_get_tokenizer( + ctx.model_config.tokenizer, + tokenizer_mode=ctx.model_config.tokenizer_mode) - return inputs + mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder + image_token_id = mm_encoder.special_ids.img + image_break_id = mm_encoder.special_ids.img_break + image_end_id = mm_encoder.special_ids.img_end + + if image_token_id not in inputs['prompt_token_ids']: + raise ValueError( + f"You've passed {inputs=} without {image_token_id=}" + " Make sure to process your input via mistral_common's" + " tokenizer or pass a chat completion request. For more" + " For more info, see: " + "https://github.com/vllm-project/vllm/issues/8411.") + + # Get precise tracking of placeholder positions + placeholder_ranges = [] + curr_offset = -1 + curr_length = 0 + for i in range(len(prompt_token_ids)): + if prompt_token_ids[i] in (image_token_id, image_break_id): + if curr_offset < 0: + curr_offset = i + curr_length += 1 + elif prompt_token_ids[i] == image_end_id: + curr_length += 1 + placeholder_ranges.append( + PlaceholderRange(offset=curr_offset, length=curr_length)) + curr_offset = -1 + curr_length = 0 + else: + pass + return token_inputs(prompt=prompt, + prompt_token_ids=prompt_token_ids, + multi_modal_data=multi_modal_data, + multi_modal_placeholders={"image": placeholder_ranges}) @MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_pixtral) @@ -192,11 +225,29 @@ def sampler(self): return get_sampler() def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: - image_input = 
self._parse_and_validate_image_input(**kwargs) + image_input, image_tokens = self._parse_and_validate_image_input( + **kwargs) if image_input is None: return None + vision_embeddings = self._process_image_input(image_input) - return vision_embeddings + + # NOTE: We patch the outputs of the vision encoder with embeddings + # from `[IMG_BREAK]` and `[IMG_END]` tokens. + image_embeds = self.language_model.get_input_embeddings(image_tokens) + image_token_mask = image_tokens == self.vision_args.image_token_id + image_embeds[image_token_mask] = vision_embeddings + + # NOTE: Image embeddings are split into separate tensors for each image + # by the indices of `[IMG_END]` token. + split_indices = torch.where( + image_tokens == PIXTRAL_IMAGE_END_ID)[0] + 1 + if len(split_indices) <= 1: + # Do not split, return as tensor of shape [1, fs, hs] + return image_embeds.unsqueeze(0) + + image_embeds = image_embeds.tensor_split(split_indices.cpu()) + return image_embeds def get_input_embeddings( self, @@ -206,8 +257,10 @@ def get_input_embeddings( inputs_embeds = self.language_model.get_input_embeddings(input_ids) if multimodal_embeddings is not None: inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, - self.vision_args.image_token_id) + input_ids, inputs_embeds, multimodal_embeddings, [ + self.vision_args.image_token_id, PIXTRAL_IMAGE_END_ID, + PIXTRAL_IMAGE_BREAK_ID + ]) return inputs_embeds def forward( @@ -245,10 +298,11 @@ def forward( def _parse_and_validate_image_input( self, images: Optional[Union[List[List[torch.Tensor]], List[torch.Tensor], - torch.Tensor]] = None + torch.Tensor]] = None, + image_tokens: Optional[torch.Tensor] = None, ) -> Optional[List[torch.Tensor]]: if images is None: - return None + return None, None if isinstance(images, torch.Tensor): # if passed as batch take all images @@ -267,7 +321,16 @@ def _parse_and_validate_image_input( images = flatten_images - return images + if isinstance(image_tokens, 
torch.Tensor): + # image_tokens are batched + image_tokens = image_tokens.flatten() + elif isinstance(image_tokens, list): + # image_tokens are of different lengths thus passed as a list + image_tokens = torch.cat(image_tokens) + + assert image_tokens.dim() == 1 + + return images, image_tokens def _process_image_input(self, image_input: List[torch.Tensor]) -> torch.Tensor: diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 5ec44955dbd80..269b66806adf4 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -409,16 +409,42 @@ def merge_multimodal_embeddings( input_ids: torch.Tensor, inputs_embeds: torch.Tensor, multimodal_embeddings: NestedTensors, - placeholder_token_id: int, + placeholder_token_id: Union[int, List[int]], ) -> torch.Tensor: """ Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the positions in ``inputs_embeds`` corresponding to placeholder tokens in ``input_ids``. + + ``placeholder_token_id`` can be a list of token ids (e.g, token ids + of img_start, img_break, and img_end tokens) when needed: This means + the order of these tokens in the ``input_ids`` MUST MATCH the order of + their embeddings in ``multimodal_embeddings`` since we need to + slice-merge instead of individually scattering. + + For example, if input_ids is "TTTTTSIIIBIIIBIIIETTT", where + - T is text token + - S is image start token + - I is image embedding token + - B is image break token + - E is image end token. + + Then the image embeddings (that correspond to I's) from vision encoder + must be padded with embeddings of S, B, and E in the same order of + input_ids for a correct embedding merge. Note: This updates ``inputs_embeds`` in place. 
""" + if isinstance(placeholder_token_id, list): + placeholder_token_id = torch.tensor(placeholder_token_id, + device=input_ids.device) + return _merge_multimodal_embeddings( + inputs_embeds, + torch.isin(input_ids, placeholder_token_id), + multimodal_embeddings, + ) + return _merge_multimodal_embeddings( inputs_embeds, (input_ids == placeholder_token_id), diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py index 640c7c04b8817..229a8fbdf5831 100644 --- a/vllm/multimodal/inputs.py +++ b/vllm/multimodal/inputs.py @@ -96,7 +96,8 @@ class PlaceholderRange(TypedDict): """The length of the placeholder.""" -NestedTensors = Union[List["NestedTensors"], List[torch.Tensor], torch.Tensor] +NestedTensors = Union[List["NestedTensors"], List[torch.Tensor], torch.Tensor, + Tuple[torch.Tensor, ...]] """ Uses a list instead of a tensor if the dimensions of each element do not match. """ diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py index d4333b7519b47..c898ca4e6573e 100644 --- a/vllm/multimodal/utils.py +++ b/vllm/multimodal/utils.py @@ -535,11 +535,13 @@ def repeat_and_pad_placeholder_tokens( return new_prompt, new_token_ids, placeholder_ranges -def consecutive_placeholder_ranges(num_items: int, - item_size: int) -> List[PlaceholderRange]: +def consecutive_placeholder_ranges( + num_items: int, + item_size: int, + initial_offset: int = 0) -> List[PlaceholderRange]: """Returns a list of consecutive PlaceholderRanges of a fixed size""" return [ - PlaceholderRange(offset=i * item_size, length=item_size) - for i in range(num_items) + PlaceholderRange(offset=initial_offset + i * item_size, + length=item_size) for i in range(num_items) ] diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index f1f26f4e8d443..1203d35fc985f 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -73,12 +73,12 @@ def __init__( # has the Transformer architecture (e.g., ViT). # FIXME(woosuk): Below are placeholder values. 
We need to calculate the # actual values from the configurations. - self.max_num_encoder_input_tokens = 2048 + self.max_num_encoder_input_tokens = 16384 # NOTE(woosuk): For the models without encoder (e.g., text-only models), # the encoder cache will not be initialized and used, regardless of # the cache size. This is because the memory space for the encoder cache # is preallocated in the profiling run. - self.encoder_cache_manager = EncoderCacheManager(cache_size=2048) + self.encoder_cache_manager = EncoderCacheManager(cache_size=16384) def schedule(self) -> "SchedulerOutput": # NOTE(woosuk) on the scheduling algorithm: diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 312c0242a45dd..994e68669108e 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -1,5 +1,7 @@ from typing import Dict, List, Mapping, Optional, Type, Union +from typing_extensions import TypeVar + from vllm.config import VllmConfig from vllm.engine.arg_utils import EngineArgs from vllm.engine.metrics_types import StatLoggerBase @@ -12,7 +14,8 @@ from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams -from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs +from vllm.transformers_utils.tokenizer_group import ( + BaseTokenizerGroup, init_tokenizer_from_configs) from vllm.usage.usage_lib import UsageContext from vllm.v1.engine.core_client import EngineCoreClient from vllm.v1.engine.detokenizer import Detokenizer @@ -21,6 +24,8 @@ logger = init_logger(__name__) +_G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup) + class LLMEngine: """Legacy LLMEngine for backwards compatibility.""" @@ -169,5 +174,18 @@ def start_profile(self): def stop_profile(self): self.engine_core.profile(False) - def get_tokenizer_group(self, group_type): - pass + def get_tokenizer_group( + self, + group_type: Type[_G] = 
BaseTokenizerGroup, + ) -> _G: + tokenizer_group = self.tokenizer + + if tokenizer_group is None: + raise ValueError("Unable to get tokenizer because " + "skip_tokenizer_init is True") + if not isinstance(tokenizer_group, group_type): + raise TypeError("Invalid type of tokenizer group. " + f"Expected type: {group_type}, but " + f"found type: {type(tokenizer_group)}") + + return tokenizer_group diff --git a/vllm/v1/engine/mm_input_mapper.py b/vllm/v1/engine/mm_input_mapper.py index 45882f8f076d4..7ad6882b04520 100644 --- a/vllm/v1/engine/mm_input_mapper.py +++ b/vllm/v1/engine/mm_input_mapper.py @@ -33,7 +33,7 @@ def process_inputs( num_images = len(image_inputs) for i in range(num_images): mm_input = self.multi_modal_input_mapper( - {"image": [image_inputs[i]]}, + {"image": image_inputs[i]}, mm_processor_kwargs=mm_processor_kwargs, ) mm_inputs.append(mm_input) From 43b05fa314e90e551d87211e8bdde2e2bb5a0bdc Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 8 Dec 2024 11:18:18 -0800 Subject: [PATCH 05/19] [torch.compile][misc] fix comments (#10993) Signed-off-by: youkaichao --- vllm/config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 164622b5af34e..38cf642b23cda 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2177,8 +2177,8 @@ class CompilationConfig(BaseModel): TODO: move outside cudagraph logic into compilation. torch.compile will handle cudagraph capture logic in the future. - cudagraph_capture_sizes: sizes to capture cudagraph. - - None: capture sizes are inferred from compilation context. - - List[int]: capture sizes are specified. + - None (default): capture sizes are inferred from vllm config. + - List[int]: capture sizes are specified as given. - cudagraph_num_of_warmups: number of warmup runs for cudagraph. It means the first several runs will be treated as warmup runs. 
Only after that, the execution will be recorded, and the recorded From 46004e83a2e0b908f28099d93171bfb4934e4722 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 8 Dec 2024 17:28:27 -0800 Subject: [PATCH 06/19] [misc] clean up and unify logging (#10999) Signed-off-by: youkaichao --- vllm/config.py | 73 ++++++++++++++++++--------------------- vllm/engine/llm_engine.py | 54 ++--------------------------- 2 files changed, 37 insertions(+), 90 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 38cf642b23cda..7fbe04eaaf4f8 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2579,45 +2579,40 @@ def __post_init__(self): self.instance_id = random_uuid()[:5] def __str__(self): - return ("model=%r, speculative_config=%r, tokenizer=%r, " - "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, " - "override_neuron_config=%s, tokenizer_revision=%s, " - "trust_remote_code=%s, dtype=%s, max_seq_len=%d, " - "download_dir=%r, load_format=%s, tensor_parallel_size=%d, " - "pipeline_parallel_size=%d, " - "disable_custom_all_reduce=%s, quantization=%s, " - "enforce_eager=%s, kv_cache_dtype=%s, " - "quantization_param_path=%s, device_config=%s, " - "decoding_config=%r, observability_config=%r, " - "seed=%d, served_model_name=%s, " - "num_scheduler_steps=%d, enable_prefix_caching=%s, " - "use_async_output_proc=%s, mm_processor_kwargs=%s") % \ - (self.model_config.model, self.speculative_config, - self.model_config.tokenizer, - self.model_config.skip_tokenizer_init, - self.model_config.tokenizer_mode, - self.model_config.revision, - self.model_config.override_neuron_config, - self.model_config.tokenizer_revision, - self.model_config.trust_remote_code, - self.model_config.dtype, - self.model_config.max_model_len, - self.load_config.download_dir, - self.load_config.load_format, - self.parallel_config.tensor_parallel_size, - self.parallel_config.pipeline_parallel_size, - self.parallel_config.disable_custom_all_reduce, - self.model_config.quantization, - 
self.model_config.enforce_eager, - self.cache_config.cache_dtype, - self.model_config.quantization_param_path, - self.device_config.device, self.decoding_config, - self.observability_config, self.model_config.seed, - self.model_config.served_model_name, - self.scheduler_config.num_scheduler_steps, - self.cache_config.enable_prefix_caching, - self.model_config.use_async_output_proc, - self.model_config.mm_processor_kwargs) + return ( + f"model={self.model_config.model!r}," + f" speculative_config={self.speculative_config!r}," + f" tokenizer={self.model_config.tokenizer!r}, " + f"skip_tokenizer_init={self.model_config.skip_tokenizer_init}," + f" tokenizer_mode={self.model_config.tokenizer_mode}, " + f"revision={self.model_config.revision}, " + f"override_neuron_config={self.model_config.override_neuron_config}," + f" tokenizer_revision={self.model_config.tokenizer_revision}, " + f"trust_remote_code={self.model_config.trust_remote_code}, " + f"dtype={self.model_config.dtype}, " + f"max_seq_len={self.model_config.max_model_len}," + f" download_dir={self.load_config.download_dir!r}, " + f"load_format={self.load_config.load_format}, " + f"tensor_parallel_size={self.parallel_config.tensor_parallel_size}," + f" pipeline_parallel_size={self.parallel_config.pipeline_parallel_size}, " # noqa + f"disable_custom_all_reduce={self.parallel_config.disable_custom_all_reduce}, " # noqa + f"quantization={self.model_config.quantization}, " + f"enforce_eager={self.model_config.enforce_eager}, " + f"kv_cache_dtype={self.cache_config.cache_dtype}, " + f"quantization_param_path={self.model_config.quantization_param_path}," + f" device_config={self.device_config.device}, " + f"decoding_config={self.decoding_config!r}, " + f"observability_config={self.observability_config!r}, " + f"seed={self.model_config.seed}, " + f"served_model_name={self.model_config.served_model_name}, " + f"num_scheduler_steps={self.scheduler_config.num_scheduler_steps}, " + 
f"multi_step_stream_outputs={self.scheduler_config.multi_step_stream_outputs}, " # noqa + f"enable_prefix_caching={self.cache_config.enable_prefix_caching}, " + f"chunked_prefill_enabled={self.scheduler_config.chunked_prefill_enabled}, " # noqa + f"use_async_output_proc={self.model_config.use_async_output_proc}, " + f"mm_processor_kwargs={self.model_config.mm_processor_kwargs}, " + f"pooler_config={self.model_config.pooler_config!r}," + f" compilation_config={self.compilation_config!r}") _current_vllm_config: Optional[VllmConfig] = None diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 26a8c94099a11..560f84a008291 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -247,60 +247,12 @@ def __init__( ) logger.info( - "Initializing an LLM engine (v%s) with config: " - "model=%r, speculative_config=%r, tokenizer=%r, " - "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, " - "override_neuron_config=%s, tokenizer_revision=%s, " - "trust_remote_code=%s, dtype=%s, max_seq_len=%d, " - "download_dir=%r, load_format=%s, tensor_parallel_size=%d, " - "pipeline_parallel_size=%d, " - "disable_custom_all_reduce=%s, quantization=%s, " - "enforce_eager=%s, kv_cache_dtype=%s, " - "quantization_param_path=%s, device_config=%s, " - "decoding_config=%r, observability_config=%r, " - "seed=%d, served_model_name=%s, " - "num_scheduler_steps=%d, chunked_prefill_enabled=%s " - "multi_step_stream_outputs=%s, enable_prefix_caching=%s, " - "use_async_output_proc=%s, use_cached_outputs=%s, " - "mm_processor_kwargs=%s, pooler_config=%r," - "compilation_config=%r", + "Initializing an LLM engine (v%s) with config: %r," + "use_cached_outputs=%s, ", VLLM_VERSION, - self.model_config.model, - self.speculative_config, - self.model_config.tokenizer, - self.model_config.skip_tokenizer_init, - self.model_config.tokenizer_mode, - self.model_config.revision, - self.model_config.override_neuron_config, - self.model_config.tokenizer_revision, - 
self.model_config.trust_remote_code, - self.model_config.dtype, - self.model_config.max_model_len, - self.load_config.download_dir, - self.load_config.load_format, - self.parallel_config.tensor_parallel_size, - self.parallel_config.pipeline_parallel_size, - self.parallel_config.disable_custom_all_reduce, - self.model_config.quantization, - self.model_config.enforce_eager, - self.cache_config.cache_dtype, - self.model_config.quantization_param_path, - self.device_config.device, - self.decoding_config, - self.observability_config, - self.model_config.seed, - self.model_config.served_model_name, - self.scheduler_config.num_scheduler_steps, - self.scheduler_config.chunked_prefill_enabled, - self.scheduler_config.multi_step_stream_outputs, - self.cache_config.enable_prefix_caching, - self.model_config.use_async_output_proc, + vllm_config, use_cached_outputs, - self.model_config.mm_processor_kwargs, - self.model_config.pooler_config, - vllm_config.compilation_config, ) - # TODO(woosuk): Print more configs in debug mode. self.log_stats = log_stats self.use_cached_outputs = use_cached_outputs From af7c4a92e654684066e61518d6ed90feda983635 Mon Sep 17 00:00:00 2001 From: Roger Wang <136131678+ywang96@users.noreply.github.com> Date: Sun, 8 Dec 2024 22:29:16 -0800 Subject: [PATCH 07/19] [Doc][V1] Add V1 support column for multimodal models (#10998) Signed-off-by: Roger Wang --- docs/source/models/supported_models.rst | 26 ++++++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index c9b3fa8485ff1..4e5b10967e3bb 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -495,7 +495,7 @@ Text Generation --------------- .. 
list-table:: - :widths: 25 25 15 25 5 5 + :widths: 25 25 15 20 5 5 5 :header-rows: 1 * - Architecture @@ -504,47 +504,55 @@ Text Generation - Example HF Models - :ref:`LoRA ` - :ref:`PP ` + - V1 * - :code:`AriaForConditionalGeneration` - Aria - T + I - :code:`rhymes-ai/Aria` - - ✅︎ + - * - :code:`Blip2ForConditionalGeneration` - BLIP-2 - T + I\ :sup:`E` - :code:`Salesforce/blip2-opt-2.7b`, :code:`Salesforce/blip2-opt-6.7b`, etc. - - ✅︎ + - * - :code:`ChameleonForConditionalGeneration` - Chameleon - T + I - :code:`facebook/chameleon-7b` etc. - - ✅︎ + - * - :code:`FuyuForCausalLM` - Fuyu - T + I - :code:`adept/fuyu-8b` etc. - - ✅︎ + - * - :code:`ChatGLMModel` - GLM-4V - T + I - :code:`THUDM/glm-4v-9b` etc. - ✅︎ - ✅︎ + - * - :code:`H2OVLChatModel` - H2OVL - T + I\ :sup:`E+` - :code:`h2oai/h2ovl-mississippi-800m`, :code:`h2oai/h2ovl-mississippi-2b`, etc. - - ✅︎ + - * - :code:`Idefics3ForConditionalGeneration` - Idefics3 - T + I - :code:`HuggingFaceM4/Idefics3-8B-Llama3` etc. - ✅︎ + - - * - :code:`InternVLChatModel` - InternVL 2.5, Mono-InternVL, InternVL 2.0 @@ -552,96 +560,112 @@ Text Generation - :code:`OpenGVLab/InternVL2_5-4B`, :code:`OpenGVLab/Mono-InternVL-2B`, :code:`OpenGVLab/InternVL2-4B`, etc. - - ✅︎ + - ✅︎ * - :code:`LlavaForConditionalGeneration` - LLaVA-1.5 - T + I\ :sup:`E+` - :code:`llava-hf/llava-1.5-7b-hf`, :code:`TIGER-Lab/Mantis-8B-siglip-llama3` (see note), etc. - - ✅︎ + - ✅︎ * - :code:`LlavaNextForConditionalGeneration` - LLaVA-NeXT - T + I\ :sup:`E+` - :code:`llava-hf/llava-v1.6-mistral-7b-hf`, :code:`llava-hf/llava-v1.6-vicuna-7b-hf`, etc. - - ✅︎ + - * - :code:`LlavaNextVideoForConditionalGeneration` - LLaVA-NeXT-Video - T + V - :code:`llava-hf/LLaVA-NeXT-Video-7B-hf`, etc. - - ✅︎ + - * - :code:`LlavaOnevisionForConditionalGeneration` - LLaVA-Onevision - T + I\ :sup:`+` + V\ :sup:`+` - :code:`llava-hf/llava-onevision-qwen2-7b-ov-hf`, :code:`llava-hf/llava-onevision-qwen2-0.5b-ov-hf`, etc. 
- - ✅︎ + - * - :code:`MiniCPMV` - MiniCPM-V - T + I\ :sup:`E+` - :code:`openbmb/MiniCPM-V-2` (see note), :code:`openbmb/MiniCPM-Llama3-V-2_5`, :code:`openbmb/MiniCPM-V-2_6`, etc. - ✅︎ - ✅︎ + - * - :code:`MllamaForConditionalGeneration` - Llama 3.2 - T + I\ :sup:`+` - :code:`meta-llama/Llama-3.2-90B-Vision-Instruct`, :code:`meta-llama/Llama-3.2-11B-Vision`, etc. - - + - * - :code:`MolmoForCausalLM` - Molmo - T + I - :code:`allenai/Molmo-7B-D-0924`, :code:`allenai/Molmo-72B-0924`, etc. - - ✅︎ + - ✅︎ * - :code:`NVLM_D_Model` - NVLM-D 1.0 - T + I\ :sup:`E+` - :code:`nvidia/NVLM-D-72B`, etc. - - ✅︎ + - ✅︎ * - :code:`PaliGemmaForConditionalGeneration` - PaliGemma - T + I\ :sup:`E` - :code:`google/paligemma-3b-pt-224`, :code:`google/paligemma-3b-mix-224`, etc. - - ✅︎ + - * - :code:`Phi3VForCausalLM` - Phi-3-Vision, Phi-3.5-Vision - T + I\ :sup:`E+` - :code:`microsoft/Phi-3-vision-128k-instruct`, :code:`microsoft/Phi-3.5-vision-instruct` etc. - - ✅︎ + - ✅︎ * - :code:`PixtralForConditionalGeneration` - Pixtral - T + I\ :sup:`+` - :code:`mistralai/Pixtral-12B-2409`, :code:`mistral-community/pixtral-12b` etc. - - ✅︎ + - ✅︎ * - :code:`QWenLMHeadModel` - Qwen-VL - T + I\ :sup:`E+` - :code:`Qwen/Qwen-VL`, :code:`Qwen/Qwen-VL-Chat`, etc. - ✅︎ - ✅︎ + - * - :code:`Qwen2AudioForConditionalGeneration` - Qwen2-Audio - T + A\ :sup:`+` - :code:`Qwen/Qwen2-Audio-7B-Instruct` - - ✅︎ + - * - :code:`Qwen2VLForConditionalGeneration` - Qwen2-VL - T + I\ :sup:`E+` + V\ :sup:`E+` - :code:`Qwen/Qwen2-VL-2B-Instruct`, :code:`Qwen/Qwen2-VL-7B-Instruct`, :code:`Qwen/Qwen2-VL-72B-Instruct`, etc. - ✅︎ - ✅︎ + - * - :code:`UltravoxModel` - Ultravox - T + A\ :sup:`E+` - :code:`fixie-ai/ultravox-v0_3` - - ✅︎ + - | :sup:`E` Pre-computed embeddings can be inputted for this modality. | :sup:`+` Multiple items can be inputted per text prompt for this modality. 
From d1c2e15eb31ef12e688ce0cb71895f88eaf4cd4f Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 8 Dec 2024 23:09:04 -0800 Subject: [PATCH 08/19] [torch.compile] add dynamo time tracking (#11005) Signed-off-by: youkaichao --- vllm/compilation/backends.py | 6 ++++++ vllm/compilation/decorators.py | 6 +++--- vllm/compilation/monitor.py | 9 +++++++-- 3 files changed, 16 insertions(+), 5 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 1206424ae1e3f..f002a8ff905b1 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -265,7 +265,13 @@ def configure_post_pass(self): def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: + # when dynamo calls the backend, it means the bytecode + # transform and analysis are done compilation_counter.num_graphs_seen += 1 + from .monitor import torch_compile_start_time + dynamo_time = time.time() - torch_compile_start_time + logger.info("Dynamo bytecode transform time: %.2f s", dynamo_time) + self.compilation_configs.compilation_time += dynamo_time # we control the compilation process, each instance can only be # called once diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index a32dced57e5b3..938430fe2a501 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -145,6 +145,7 @@ def _support_torch_compile( def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs): old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs) + self.vllm_config = vllm_config # for CompilationLevel.DYNAMO_AS_IS , the upper level model runner # will handle the compilation, so we don't need to do anything here. 
self.do_not_compile = \ @@ -157,9 +158,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs): TorchCompileWrapperWithCustomDispatcher.__init__( self, compilation_level=vllm_config.compilation_config.level) - if vllm_config.compilation_config.level == CompilationLevel.PIECEWISE: - start_monitoring_torch_compile(vllm_config.compilation_config) - cls.__init__ = __init__ def __call__(self, *args, **kwargs): @@ -186,6 +184,8 @@ def __call__(self, *args, **kwargs): raise ValueError( "Unsupported dynamic dimensions" f" {dims} for argument {k} with type {type(arg)}.") + # here, it is the starting point of the `torch.compile` process + start_monitoring_torch_compile(self.vllm_config.compilation_config) # if we don't use custom dispatcher, we can directly call the # compiled function and let torch.compile handle the dispatching, diff --git a/vllm/compilation/monitor.py b/vllm/compilation/monitor.py index f718e46423212..3348674b09af2 100644 --- a/vllm/compilation/monitor.py +++ b/vllm/compilation/monitor.py @@ -1,14 +1,19 @@ +import time + from vllm.config import CompilationConfig, CompilationLevel from vllm.logger import init_logger logger = init_logger(__name__) +torch_compile_start_time: float = 0.0 + def start_monitoring_torch_compile(compilation_config: CompilationConfig): - pass + global torch_compile_start_time + torch_compile_start_time = time.time() def end_monitoring_torch_compile(compilation_config: CompilationConfig): if compilation_config.level == CompilationLevel.PIECEWISE: - logger.info("graph compilation takes %.2f s in total", + logger.info("torch.compile takes %.2f s in total", compilation_config.compilation_time) From c690357928fd2812f450bfb0c3629a816f5e9a55 Mon Sep 17 00:00:00 2001 From: Roger Wang <136131678+ywang96@users.noreply.github.com> Date: Mon, 9 Dec 2024 08:27:10 -0800 Subject: [PATCH 09/19] [V1] Fix Detokenizer loading in `AsyncLLM` (#10997) Signed-off-by: Roger Wang --- vllm/v1/engine/async_llm.py | 7 ++++++- 1 file 
changed, 6 insertions(+), 1 deletion(-) diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 4ef372fd8464b..0bcccda2bf329 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -65,7 +65,12 @@ def __init__( input_registry) # Detokenizer (converts EngineCoreOutputs --> RequestOutput). - self.detokenizer = Detokenizer(vllm_config.model_config.tokenizer) + self.detokenizer = Detokenizer( + tokenizer_name=vllm_config.model_config.tokenizer, + tokenizer_mode=vllm_config.model_config.tokenizer_mode, + trust_remote_code=vllm_config.model_config.trust_remote_code, + revision=vllm_config.model_config.tokenizer_revision, + ) # EngineCore (starts the engine in background process). self.engine_core = EngineCoreClient.make_client( From e691b26f6fae5a3a1c220d15f20de83c7d78ed51 Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Mon, 9 Dec 2024 11:44:27 -0500 Subject: [PATCH 10/19] [Core] Require xgrammar >= 0.1.6 (#11021) Signed-off-by: Russell Bryant --- requirements-common.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-common.txt b/requirements-common.txt index 72fb020a82c4e..112528880c0ac 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -19,7 +19,7 @@ prometheus-fastapi-instrumentator >= 7.0.0 tiktoken >= 0.6.0 # Required for DBRX tokenizer lm-format-enforcer >= 0.10.9, < 0.11 outlines >= 0.0.43, < 0.1 -xgrammar >= 0.1.5; platform_machine == "x86_64" +xgrammar >= 0.1.6; platform_machine == "x86_64" typing_extensions >= 4.10 filelock >= 3.16.1 # need to contain https://github.com/tox-dev/filelock/pull/317 partial-json-parser # used for parsing partial JSON outputs From aea2fc38c3b31b9a8ea7d1cffb8f37a2da6f6075 Mon Sep 17 00:00:00 2001 From: wangxiyuan Date: Tue, 10 Dec 2024 01:24:46 +0800 Subject: [PATCH 11/19] [Platform] Move `async output` check to platform (#10768) Signed-off-by: wangxiyuan --- vllm/config.py | 17 +++-------------- vllm/platforms/cpu.py | 6 +++++- 
vllm/platforms/cuda.py | 12 +++++++++++- vllm/platforms/hpu.py | 6 +++++- vllm/platforms/interface.py | 11 +++++++++++ vllm/platforms/neuron.py | 6 +++++- vllm/platforms/openvino.py | 6 +++++- vllm/platforms/rocm.py | 12 +++++++++++- vllm/platforms/tpu.py | 6 +++++- vllm/platforms/xpu.py | 6 +++++- 10 files changed, 66 insertions(+), 22 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 7fbe04eaaf4f8..29f0839dcabba 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -513,11 +513,10 @@ def verify_async_output_proc(self, parallel_config, speculative_config, # Reminder: Please update docs/source/usage/compatibility_matrix.rst # If the feature combo become valid - if device_config.device_type not in ("cuda", "tpu", "xpu", "hpu"): + if not current_platform.is_async_output_supported(self.enforce_eager): logger.warning( - "Async output processing is only supported for CUDA, TPU, XPU " - "and HPU." - "Disabling it for other platforms.") + "Async output processing is not supported on the " + "current platform type %s.", current_platform.device_type) self.use_async_output_proc = False return @@ -527,16 +526,6 @@ def verify_async_output_proc(self, parallel_config, speculative_config, self.use_async_output_proc = False return - # Reminder: Please update docs/source/usage/compatibility_matrix.rst - # If the feature combo become valid - if device_config.device_type == "cuda" and self.enforce_eager: - logger.warning( - "To see benefits of async output processing, enable CUDA " - "graph. 
Since, enforce-eager is enabled, async output " - "processor cannot be used") - self.use_async_output_proc = not self.enforce_eager - return - # Async postprocessor is not necessary with embedding mode # since there is no token generation if self.task == "embedding": diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py index 680ee74129739..e5142b985d1f2 100644 --- a/vllm/platforms/cpu.py +++ b/vllm/platforms/cpu.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional import psutil import torch @@ -37,6 +37,10 @@ def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend: def get_device_total_memory(cls, device_id: int = 0) -> int: return psutil.virtual_memory().total + @classmethod + def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: + return False + @classmethod def inference_mode(cls): return torch.no_grad() diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 846a1869da228..edaf377b501df 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -4,7 +4,7 @@ import os from functools import lru_cache, wraps -from typing import TYPE_CHECKING, Callable, List, TypeVar +from typing import TYPE_CHECKING, Callable, List, Optional, TypeVar import pynvml import torch @@ -88,6 +88,16 @@ def get_device_name(cls, device_id: int = 0) -> str: def get_device_total_memory(cls, device_id: int = 0) -> int: raise NotImplementedError + @classmethod + def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: + if enforce_eager: + logger.warning( + "To see benefits of async output processing, enable CUDA " + "graph. 
Since, enforce-eager is enabled, async output " + "processor cannot be used") + return False + return True + @classmethod def is_full_nvlink(cls, device_ids: List[int]) -> bool: raise NotImplementedError diff --git a/vllm/platforms/hpu.py b/vllm/platforms/hpu.py index 10aaa6d54962c..7f22bee3eaa74 100644 --- a/vllm/platforms/hpu.py +++ b/vllm/platforms/hpu.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional import torch @@ -20,6 +20,10 @@ class HpuPlatform(Platform): def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend: return _Backend.HPU_ATTN + @classmethod + def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: + return True + @staticmethod def inference_mode(): return torch.no_grad() diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 0be7df7941b8b..db06d2c18e681 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -6,11 +6,15 @@ import numpy as np import torch +from vllm.logger import init_logger + if TYPE_CHECKING: from vllm.config import VllmConfig else: VllmConfig = None +logger = init_logger(__name__) + class _Backend(enum.Enum): FLASH_ATTN = enum.auto() @@ -147,6 +151,13 @@ def get_device_total_memory(cls, device_id: int = 0) -> int: """Get the total memory of a device in bytes.""" raise NotImplementedError + @classmethod + def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: + """ + Check if the current platform supports async output. + """ + raise NotImplementedError + @classmethod def inference_mode(cls): """A device-specific wrapper of `torch.inference_mode`. 
diff --git a/vllm/platforms/neuron.py b/vllm/platforms/neuron.py index 87655ea198303..1e5c4bddfa24f 100644 --- a/vllm/platforms/neuron.py +++ b/vllm/platforms/neuron.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional from .interface import Platform, PlatformEnum @@ -18,6 +18,10 @@ class NeuronPlatform(Platform): def get_device_name(cls, device_id: int = 0) -> str: return "neuron" + @classmethod + def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: + return False + @classmethod def check_and_update_config(cls, vllm_config: VllmConfig) -> None: parallel_config = vllm_config.parallel_config diff --git a/vllm/platforms/openvino.py b/vllm/platforms/openvino.py index 29b61e955d9ab..e0f8e8b4b49fe 100644 --- a/vllm/platforms/openvino.py +++ b/vllm/platforms/openvino.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional import torch @@ -37,6 +37,10 @@ def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend: def get_device_name(self, device_id: int = 0) -> str: return "openvino" + @classmethod + def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: + return False + @classmethod def inference_mode(self): return torch.inference_mode(mode=True) diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 3c14fbc179f69..66674e3ebe91f 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -1,6 +1,6 @@ import os from functools import lru_cache -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional import torch @@ -72,6 +72,16 @@ def get_device_total_memory(cls, device_id: int = 0) -> int: device_props = torch.cuda.get_device_properties(device_id) return device_props.total_memory + @classmethod + def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: + if enforce_eager: + logger.warning( + "To see benefits of async output processing, enable CUDA " + "graph. 
Since, enforce-eager is enabled, async output " + "processor cannot be used") + return False + return True + @classmethod def check_and_update_config(cls, vllm_config: VllmConfig) -> None: parallel_config = vllm_config.parallel_config diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index b138f7e1c54c5..10d874349f36b 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional import torch @@ -35,6 +35,10 @@ def get_device_name(cls, device_id: int = 0) -> str: def get_device_total_memory(cls, device_id: int = 0) -> int: raise NotImplementedError + @classmethod + def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: + return True + @classmethod def inference_mode(cls): return torch.no_grad() diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py index 9665786f4c499..11dbd04d55671 100644 --- a/vllm/platforms/xpu.py +++ b/vllm/platforms/xpu.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional import torch @@ -41,6 +41,10 @@ def get_device_total_memory(cls, device_id: int = 0) -> int: device_props = torch.xpu.get_device_properties(device_id) return device_props.total_memory + @classmethod + def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: + return True + @staticmethod def inference_mode(): return torch.no_grad() From 25b79d9fd38e2c53ce281be23241d8939ec7320c Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Mon, 9 Dec 2024 12:33:41 -0500 Subject: [PATCH 12/19] [V1] Input Batch Relocation (#10962) Signed-off-by: Varun Sundar Rabindranath Co-authored-by: Varun Sundar Rabindranath --- vllm/v1/worker/gpu_input_batch.py | 280 +++++++++++++++++++++++++++++ vllm/v1/worker/gpu_model_runner.py | 273 +--------------------------- 2 files changed, 283 insertions(+), 270 deletions(-) create mode 100644 vllm/v1/worker/gpu_input_batch.py diff --git 
a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py new file mode 100644 index 0000000000000..457784bb0287c --- /dev/null +++ b/vllm/v1/worker/gpu_input_batch.py @@ -0,0 +1,280 @@ +# Datastructures defining an input batch + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Dict, List, Optional, Set + +import numpy as np +import torch + +from vllm.multimodal import MultiModalKwargs +from vllm.sampling_params import SamplingParams, SamplingType +from vllm.v1.sample.metadata import SamplingMetadata + +if TYPE_CHECKING: + from vllm.multimodal.inputs import PlaceholderRange + + +@dataclass +class CachedRequestState: + + req_id: str + prompt_token_ids: List[int] + prompt: Optional[str] + mm_inputs: List[MultiModalKwargs] + mm_positions: List["PlaceholderRange"] + sampling_params: SamplingParams + generator: Optional[torch.Generator] + + block_ids: List[int] + num_computed_tokens: int + output_token_ids: List[int] + + @property + def num_tokens(self) -> int: + return len(self.prompt_token_ids) + len(self.output_token_ids) + + +class InputBatch: + + def __init__( + self, + max_num_reqs: int, + max_model_len: int, + max_num_blocks_per_req: int, + device: torch.device, + pin_memory: bool, + ): + self.max_num_reqs = max_num_reqs + self.max_model_len = max_model_len + self.max_num_blocks_per_req = max_num_blocks_per_req + self.device = device + self.pin_memory = pin_memory + + self.req_ids: List[Optional[str]] = [None] * max_num_reqs + self.req_id_to_index: Dict[str, int] = {} + + self.token_ids_cpu = np.empty((max_num_reqs, max_model_len), + dtype=np.int32) + self.num_computed_tokens_cpu = np.empty(max_num_reqs, dtype=np.int32) + + # Attention-related. 
+ self.block_table = torch.zeros((max_num_reqs, max_num_blocks_per_req), + device=self.device, + dtype=torch.int32) + self.block_table_cpu_tensor = torch.zeros( + (max_num_reqs, max_num_blocks_per_req), + device="cpu", + dtype=torch.int32, + pin_memory=pin_memory, + ) + self.block_table_cpu = self.block_table_cpu_tensor.numpy() + + # Sampling-related. + self.temperature = torch.empty((max_num_reqs, ), + dtype=torch.float32, + device=device) + self.temperature_cpu_tensor = torch.empty((max_num_reqs, ), + dtype=torch.float32, + device="cpu", + pin_memory=pin_memory) + self.temperature_cpu = self.temperature_cpu_tensor.numpy() + self.greedy_reqs: Set[str] = set() + self.random_reqs: Set[str] = set() + + self.top_p = torch.empty((max_num_reqs, ), + dtype=torch.float32, + device=device) + self.top_p_cpu_tensor = torch.empty((max_num_reqs, ), + dtype=torch.float32, + device="cpu", + pin_memory=pin_memory) + self.top_p_cpu = self.top_p_cpu_tensor.numpy() + self.top_p_reqs: Set[str] = set() + + self.top_k = torch.empty((max_num_reqs, ), + dtype=torch.int32, + device=device) + self.top_k_cpu_tensor = torch.empty((max_num_reqs, ), + dtype=torch.int32, + device="cpu", + pin_memory=pin_memory) + self.top_k_cpu = self.top_k_cpu_tensor.numpy() + self.top_k_reqs: Set[str] = set() + + # req_index -> generator + self.generators: Dict[int, torch.Generator] = {} + + self.num_logprobs: Dict[str, int] = {} + self.prompt_logprob_reqs: Set[str] = set() + + def add_request( + self, + request: "CachedRequestState", + req_index: Optional[int] = None, + ) -> None: + if req_index is None: + req_index = self.num_reqs + assert req_index < self.max_num_reqs + + req_id = request.req_id + self.req_ids[req_index] = req_id + self.req_id_to_index[req_id] = req_index + + # Copy the prompt token ids and output token ids. 
+ num_prompt_tokens = len(request.prompt_token_ids) + self.token_ids_cpu[ + req_index, :num_prompt_tokens] = request.prompt_token_ids + start_idx = num_prompt_tokens + end_idx = start_idx + len(request.output_token_ids) + self.token_ids_cpu[req_index, + start_idx:end_idx] = request.output_token_ids + + self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens + num_blocks = len(request.block_ids) + self.block_table_cpu[req_index, :num_blocks] = request.block_ids + + sampling_params = request.sampling_params + self.temperature_cpu[req_index] = sampling_params.temperature + if sampling_params.sampling_type == SamplingType.GREEDY: + self.greedy_reqs.add(req_id) + else: + self.random_reqs.add(req_id) + + self.top_p_cpu[req_index] = sampling_params.top_p + if sampling_params.top_p < 1: + self.top_p_reqs.add(req_id) + self.top_k_cpu[req_index] = sampling_params.top_k + if sampling_params.top_k > 0: + self.top_k_reqs.add(req_id) + + self.generators[req_index] = request.generator + + num_logprobs = sampling_params.logprobs + if num_logprobs is not None and num_logprobs > 0: + self.num_logprobs[req_id] = num_logprobs + if sampling_params.prompt_logprobs: + self.prompt_logprob_reqs.add(req_id) + + def remove_request(self, req_id: str) -> Optional[int]: + req_index = self.req_id_to_index.pop(req_id, None) + if req_index is None: + return None + self.req_ids[req_index] = None + + self.greedy_reqs.discard(req_id) + self.random_reqs.discard(req_id) + self.top_p_reqs.discard(req_id) + self.top_k_reqs.discard(req_id) + self.generators.pop(req_index, None) + self.num_logprobs.pop(req_id, None) + self.prompt_logprob_reqs.discard(req_id) + return req_index + + def clear(self) -> None: + self.req_ids = [None] * self.max_num_reqs + self.req_id_to_index.clear() + self.greedy_reqs.clear() + self.random_reqs.clear() + self.top_p_reqs.clear() + self.top_k_reqs.clear() + self.generators.clear() + self.num_logprobs.clear() + self.prompt_logprob_reqs.clear() + + def condense(self, 
empty_req_indices: List[int]) -> None: + if self.num_reqs == 0: + # The batched states are empty. + return + + # NOTE(woosuk): This function assumes that the empty_req_indices + # is sorted in descending order. + last_req_index = self.num_reqs + len(empty_req_indices) - 1 + while empty_req_indices: + # Find the largest non-empty index. + while last_req_index in empty_req_indices: + last_req_index -= 1 + + # Find the smallest empty index. + empty_index = empty_req_indices.pop() + if empty_index >= last_req_index: + break + + # Swap the states. + req_id = self.req_ids[last_req_index] + self.req_ids[empty_index] = req_id + self.req_ids[last_req_index] = None + self.req_id_to_index[req_id] = empty_index + + # TODO(woosuk): Optimize the copy of token_ids_cpu and + # block_table_cpu. + self.token_ids_cpu[empty_index] = self.token_ids_cpu[ + last_req_index] + self.num_computed_tokens_cpu[ + empty_index] = self.num_computed_tokens_cpu[last_req_index] + self.block_table_cpu[empty_index] = self.block_table_cpu[ + last_req_index] + self.temperature_cpu[empty_index] = self.temperature_cpu[ + last_req_index] + self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index] + self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index] + generator = self.generators.pop(last_req_index, None) + if generator is not None: + self.generators[empty_index] = generator + + # Decrement last_req_index since it is now empty. 
+ last_req_index -= 1 + + def make_sampling_metadata( + self, + skip_copy: bool = False, + ) -> SamplingMetadata: + if not skip_copy: + self.temperature[:self.num_reqs].copy_( + self.temperature_cpu_tensor[:self.num_reqs], non_blocking=True) + self.top_p[:self.num_reqs].copy_( + self.top_p_cpu_tensor[:self.num_reqs], non_blocking=True) + self.top_k[:self.num_reqs].copy_( + self.top_k_cpu_tensor[:self.num_reqs], non_blocking=True) + return SamplingMetadata( + temperature=self.temperature[:self.num_reqs], + all_greedy=self.all_greedy, + all_random=self.all_random, + top_p=self.top_p[:self.num_reqs], + top_k=self.top_k[:self.num_reqs], + no_top_p=self.no_top_p, + no_top_k=self.no_top_k, + generators=self.generators, + max_num_logprobs=self.max_num_logprobs, + ) + + @property + def num_reqs(self) -> int: + return len(self.req_id_to_index) + + @property + def all_greedy(self) -> bool: + return len(self.random_reqs) == 0 + + @property + def all_random(self) -> bool: + return len(self.greedy_reqs) == 0 + + @property + def no_top_p(self) -> bool: + return len(self.top_p_reqs) == 0 + + @property + def no_top_k(self) -> bool: + return len(self.top_k_reqs) == 0 + + @property + def max_num_logprobs(self) -> int: + return max(self.num_logprobs.values()) if self.num_logprobs else 0 + + @property + def no_logprob(self) -> bool: + return len(self.num_logprobs) == 0 + + @property + def no_prompt_logprob(self) -> bool: + return len(self.prompt_logprob_reqs) == 0 diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index e8d964a722f60..7f95be06188e3 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1,7 +1,6 @@ import gc import time -from dataclasses import dataclass -from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple import numpy as np import torch @@ -15,16 +14,16 @@ from vllm.logger import init_logger from 
vllm.model_executor.model_loader import get_model from vllm.multimodal import MultiModalKwargs -from vllm.sampling_params import SamplingParams, SamplingType +from vllm.sampling_params import SamplingType from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, cdiv, is_pin_memory_available) from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend, FlashAttentionMetadata) from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch if TYPE_CHECKING: - from vllm.multimodal.inputs import PlaceholderRange from vllm.v1.core.scheduler import SchedulerOutput logger = init_logger(__name__) @@ -609,269 +608,3 @@ def _get_padded_batch_size(self, batch_size: int) -> Optional[int]: if batch_size <= size: return size return None - - -@dataclass -class CachedRequestState: - - req_id: str - prompt_token_ids: List[int] - prompt: Optional[str] - mm_inputs: List[MultiModalKwargs] - mm_positions: List["PlaceholderRange"] - sampling_params: SamplingParams - generator: Optional[torch.Generator] - - block_ids: List[int] - num_computed_tokens: int - output_token_ids: List[int] - - @property - def num_tokens(self) -> int: - return len(self.prompt_token_ids) + len(self.output_token_ids) - - -class InputBatch: - - def __init__( - self, - max_num_reqs: int, - max_model_len: int, - max_num_blocks_per_req: int, - device: torch.device, - pin_memory: bool, - ): - self.max_num_reqs = max_num_reqs - self.max_model_len = max_model_len - self.max_num_blocks_per_req = max_num_blocks_per_req - self.device = device - self.pin_memory = pin_memory - - self.req_ids: List[Optional[str]] = [None] * max_num_reqs - self.req_id_to_index: Dict[str, int] = {} - - self.token_ids_cpu = np.empty((max_num_reqs, max_model_len), - dtype=np.int32) - self.num_computed_tokens_cpu = np.empty(max_num_reqs, dtype=np.int32) - - # Attention-related. 
- self.block_table = torch.zeros((max_num_reqs, max_num_blocks_per_req), - device=self.device, - dtype=torch.int32) - self.block_table_cpu_tensor = torch.zeros( - (max_num_reqs, max_num_blocks_per_req), - device="cpu", - dtype=torch.int32, - pin_memory=pin_memory, - ) - self.block_table_cpu = self.block_table_cpu_tensor.numpy() - - # Sampling-related. - self.temperature = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device=device) - self.temperature_cpu_tensor = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device="cpu", - pin_memory=pin_memory) - self.temperature_cpu = self.temperature_cpu_tensor.numpy() - self.greedy_reqs: Set[str] = set() - self.random_reqs: Set[str] = set() - - self.top_p = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device=device) - self.top_p_cpu_tensor = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device="cpu", - pin_memory=pin_memory) - self.top_p_cpu = self.top_p_cpu_tensor.numpy() - self.top_p_reqs: Set[str] = set() - - self.top_k = torch.empty((max_num_reqs, ), - dtype=torch.int32, - device=device) - self.top_k_cpu_tensor = torch.empty((max_num_reqs, ), - dtype=torch.int32, - device="cpu", - pin_memory=pin_memory) - self.top_k_cpu = self.top_k_cpu_tensor.numpy() - self.top_k_reqs: Set[str] = set() - - # req_index -> generator - self.generators: Dict[int, torch.Generator] = {} - - self.num_logprobs: Dict[str, int] = {} - self.prompt_logprob_reqs: Set[str] = set() - - def add_request( - self, - request: "CachedRequestState", - req_index: Optional[int] = None, - ) -> None: - if req_index is None: - req_index = self.num_reqs - assert req_index < self.max_num_reqs - - req_id = request.req_id - self.req_ids[req_index] = req_id - self.req_id_to_index[req_id] = req_index - - # Copy the prompt token ids and output token ids. 
- num_prompt_tokens = len(request.prompt_token_ids) - self.token_ids_cpu[ - req_index, :num_prompt_tokens] = request.prompt_token_ids - start_idx = num_prompt_tokens - end_idx = start_idx + len(request.output_token_ids) - self.token_ids_cpu[req_index, - start_idx:end_idx] = request.output_token_ids - - self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens - num_blocks = len(request.block_ids) - self.block_table_cpu[req_index, :num_blocks] = request.block_ids - - sampling_params = request.sampling_params - self.temperature_cpu[req_index] = sampling_params.temperature - if sampling_params.sampling_type == SamplingType.GREEDY: - self.greedy_reqs.add(req_id) - else: - self.random_reqs.add(req_id) - - self.top_p_cpu[req_index] = sampling_params.top_p - if sampling_params.top_p < 1: - self.top_p_reqs.add(req_id) - self.top_k_cpu[req_index] = sampling_params.top_k - if sampling_params.top_k > 0: - self.top_k_reqs.add(req_id) - - self.generators[req_index] = request.generator - - num_logprobs = sampling_params.logprobs - if num_logprobs is not None and num_logprobs > 0: - self.num_logprobs[req_id] = num_logprobs - if sampling_params.prompt_logprobs: - self.prompt_logprob_reqs.add(req_id) - - def remove_request(self, req_id: str) -> Optional[int]: - req_index = self.req_id_to_index.pop(req_id, None) - if req_index is None: - return None - self.req_ids[req_index] = None - - self.greedy_reqs.discard(req_id) - self.random_reqs.discard(req_id) - self.top_p_reqs.discard(req_id) - self.top_k_reqs.discard(req_id) - self.generators.pop(req_index, None) - self.num_logprobs.pop(req_id, None) - self.prompt_logprob_reqs.discard(req_id) - return req_index - - def clear(self) -> None: - self.req_ids = [None] * self.max_num_reqs - self.req_id_to_index.clear() - self.greedy_reqs.clear() - self.random_reqs.clear() - self.top_p_reqs.clear() - self.top_k_reqs.clear() - self.generators.clear() - self.num_logprobs.clear() - self.prompt_logprob_reqs.clear() - - def condense(self, 
empty_req_indices: List[int]) -> None: - if self.num_reqs == 0: - # The batched states are empty. - return - - # NOTE(woosuk): This function assumes that the empty_req_indices - # is sorted in descending order. - last_req_index = self.num_reqs + len(empty_req_indices) - 1 - while empty_req_indices: - # Find the largest non-empty index. - while last_req_index in empty_req_indices: - last_req_index -= 1 - - # Find the smallest empty index. - empty_index = empty_req_indices.pop() - if empty_index >= last_req_index: - break - - # Swap the states. - req_id = self.req_ids[last_req_index] - self.req_ids[empty_index] = req_id - self.req_ids[last_req_index] = None - self.req_id_to_index[req_id] = empty_index - - # TODO(woosuk): Optimize the copy of token_ids_cpu and - # block_table_cpu. - self.token_ids_cpu[empty_index] = self.token_ids_cpu[ - last_req_index] - self.num_computed_tokens_cpu[ - empty_index] = self.num_computed_tokens_cpu[last_req_index] - self.block_table_cpu[empty_index] = self.block_table_cpu[ - last_req_index] - self.temperature_cpu[empty_index] = self.temperature_cpu[ - last_req_index] - self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index] - self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index] - generator = self.generators.pop(last_req_index, None) - if generator is not None: - self.generators[empty_index] = generator - - # Decrement last_req_index since it is now empty. 
- last_req_index -= 1 - - def make_sampling_metadata( - self, - skip_copy: bool = False, - ) -> SamplingMetadata: - if not skip_copy: - self.temperature[:self.num_reqs].copy_( - self.temperature_cpu_tensor[:self.num_reqs], non_blocking=True) - self.top_p[:self.num_reqs].copy_( - self.top_p_cpu_tensor[:self.num_reqs], non_blocking=True) - self.top_k[:self.num_reqs].copy_( - self.top_k_cpu_tensor[:self.num_reqs], non_blocking=True) - return SamplingMetadata( - temperature=self.temperature[:self.num_reqs], - all_greedy=self.all_greedy, - all_random=self.all_random, - top_p=self.top_p[:self.num_reqs], - top_k=self.top_k[:self.num_reqs], - no_top_p=self.no_top_p, - no_top_k=self.no_top_k, - generators=self.generators, - max_num_logprobs=self.max_num_logprobs, - ) - - @property - def num_reqs(self) -> int: - return len(self.req_id_to_index) - - @property - def all_greedy(self) -> bool: - return len(self.random_reqs) == 0 - - @property - def all_random(self) -> bool: - return len(self.greedy_reqs) == 0 - - @property - def no_top_p(self) -> bool: - return len(self.top_p_reqs) == 0 - - @property - def no_top_k(self) -> bool: - return len(self.top_k_reqs) == 0 - - @property - def max_num_logprobs(self) -> int: - return max(self.num_logprobs.values()) if self.num_logprobs else 0 - - @property - def no_logprob(self) -> bool: - return len(self.num_logprobs) == 0 - - @property - def no_prompt_logprob(self) -> bool: - return len(self.prompt_logprob_reqs) == 0 From bdd0abf9d061dbfc68e24e2328475f276d21f25f Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 9 Dec 2024 18:08:29 +0000 Subject: [PATCH 13/19] removed VLLM_USE_V1 checks Signed-off-by: Andrew Feldman --- tests/v1/sample/test_logprobs.py | 15 ++++++--------- tests/v1/utils.py | 7 ------- 2 files changed, 6 insertions(+), 16 deletions(-) delete mode 100644 tests/v1/utils.py diff --git a/tests/v1/sample/test_logprobs.py b/tests/v1/sample/test_logprobs.py index 1a1d361170187..0d8031f05e8d1 100644 --- 
a/tests/v1/sample/test_logprobs.py +++ b/tests/v1/sample/test_logprobs.py @@ -5,10 +5,9 @@ import torch from tests.kernels.utils import override_backend_env_variable -from tests.v1.samplers.utils import ( +from tests.v1.sample.utils import ( assert_incr_detok_str_matches_non_incr_detok_str, compute_correct_cumulative_logprob, get_test_batch) -from tests.v1.utils import assert_vllm_use_v1 from vllm import SamplingParams from ...conftest import VllmRunner @@ -27,7 +26,6 @@ def _test_case_get_logprobs_and_prompt_logprobs( example_prompts, monkeypatch, ) -> None: - assert_vllm_use_v1() test_prompts = example_prompts override_backend_env_variable(monkeypatch, "FLASH_ATTN") @@ -287,7 +285,6 @@ def test_max_logprobs(monkeypatch): Args: monkeypatch """ - assert_vllm_use_v1() override_backend_env_variable(monkeypatch, "FLASH_ATTN") runner = VllmRunner("facebook/opt-125m", max_logprobs=1) @@ -305,12 +302,12 @@ def test_none_logprobs(vllm_runner, model, example_prompts, monkeypatch): """Engine should return `logprobs` and `prompt_logprobs` as `None` Args: - vllm_runner - model - example_prompts - monkeypatch + vllm_runner: vLLM engine runner fixture + model: model name + example_prompts: list of example prompts (test fixture) + monkeypatch: supports editing env vars and rolling back changes + after the test """ - assert_vllm_use_v1() override_backend_env_variable(monkeypatch, "FLASH_ATTN") max_num_seqs = 256 diff --git a/tests/v1/utils.py b/tests/v1/utils.py deleted file mode 100644 index db9193a487c95..0000000000000 --- a/tests/v1/utils.py +++ /dev/null @@ -1,7 +0,0 @@ -"""V1 vLLM engine test utils""" -import os - - -def assert_vllm_use_v1(): - if os.getenv("VLLM_USE_V1") != "1": - raise OSError("Test requires VLLM_USE_V1=\"1\"") From 1fc981eac6e6f521f64489745aaeec9c22654b43 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 9 Dec 2024 18:15:20 +0000 Subject: [PATCH 14/19] revert logprobs name changes Signed-off-by: Andrew Feldman --- examples/llm_engine_example.py | 4 
+-- examples/lora_with_quantization_inference.py | 16 +++++----- examples/multilora_inference.py | 12 +++---- tests/conftest.py | 14 ++++----- tests/engine/test_skip_tokenizer_init.py | 3 +- .../decoder_only/language/test_mistral.py | 4 +-- .../vision_language/test_pixtral.py | 4 +-- tests/samplers/test_logits_processor.py | 4 +-- tests/samplers/test_logprobs.py | 24 +++++++------- tests/samplers/test_ranks.py | 15 +++++---- tests/samplers/test_sampler.py | 4 +-- tests/spec_decode/e2e/conftest.py | 4 +-- tests/spec_decode/e2e/test_logprobs.py | 2 +- tests/tokenization/test_detokenize.py | 8 ++--- tests/v1/sample/test_logprobs.py | 17 +++++----- vllm/engine/llm_engine.py | 8 ++--- vllm/engine/protocol.py | 2 +- vllm/entrypoints/llm.py | 3 +- vllm/model_executor/layers/sampler.py | 17 +++++----- vllm/model_executor/sampling_metadata.py | 13 ++++---- vllm/outputs.py | 2 +- vllm/sampling_params.py | 31 ++++++++----------- vllm/spec_decode/spec_decode_worker.py | 5 ++- vllm/spec_decode/util.py | 3 +- vllm/v1/engine/processor.py | 11 +++---- vllm/v1/request.py | 4 +-- vllm/v1/worker/gpu_input_batch.py | 6 ++-- vllm/worker/hpu_model_runner.py | 8 ++--- vllm/worker/model_runner.py | 8 ++--- vllm/worker/multi_step_model_runner.py | 8 ++--- vllm/worker/tpu_model_runner.py | 4 +-- 31 files changed, 122 insertions(+), 146 deletions(-) diff --git a/examples/llm_engine_example.py b/examples/llm_engine_example.py index dc87ef3df1ce2..60d894aae9692 100644 --- a/examples/llm_engine_example.py +++ b/examples/llm_engine_example.py @@ -9,9 +9,7 @@ def create_test_prompts() -> List[Tuple[str, SamplingParams]]: """Create a list of test prompts with their sampling parameters.""" return [ ("A robot may not injure a human being", - SamplingParams(temperature=0.0, - request_sample_logprobs=1, - request_prompt_logprobs=1)), + SamplingParams(temperature=0.0, logprobs=1, prompt_logprobs=1)), ("To be or not to be,", SamplingParams(temperature=0.8, top_k=5, presence_penalty=0.2)), ("What is the 
meaning of life?", diff --git a/examples/lora_with_quantization_inference.py b/examples/lora_with_quantization_inference.py index ac2cd90ec7ceb..0c454ea50f665 100644 --- a/examples/lora_with_quantization_inference.py +++ b/examples/lora_with_quantization_inference.py @@ -22,26 +22,26 @@ def create_test_prompts( # this is an example of using quantization without LoRA ("My name is", SamplingParams(temperature=0.0, - request_sample_logprobs=1, - request_prompt_logprobs=1, + logprobs=1, + prompt_logprobs=1, max_tokens=128), None), # the next three examples use quantization with LoRA ("my name is", SamplingParams(temperature=0.0, - request_sample_logprobs=1, - request_prompt_logprobs=1, + logprobs=1, + prompt_logprobs=1, max_tokens=128), LoRARequest("lora-test-1", 1, lora_path)), ("The capital of USA is", SamplingParams(temperature=0.0, - request_sample_logprobs=1, - request_prompt_logprobs=1, + logprobs=1, + prompt_logprobs=1, max_tokens=128), LoRARequest("lora-test-2", 1, lora_path)), ("The capital of France is", SamplingParams(temperature=0.0, - request_sample_logprobs=1, - request_prompt_logprobs=1, + logprobs=1, + prompt_logprobs=1, max_tokens=128), LoRARequest("lora-test-3", 1, lora_path)), ] diff --git a/examples/multilora_inference.py b/examples/multilora_inference.py index 904bb6764b2e5..043220d979c3c 100644 --- a/examples/multilora_inference.py +++ b/examples/multilora_inference.py @@ -27,8 +27,8 @@ def create_test_prompts( return [ ("A robot may not injure a human being", SamplingParams(temperature=0.0, - request_sample_logprobs=1, - request_prompt_logprobs=1, + logprobs=1, + prompt_logprobs=1, max_tokens=128), None), ("To be or not to be,", SamplingParams(temperature=0.8, @@ -38,16 +38,16 @@ def create_test_prompts( ( "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: 
E501 SamplingParams(temperature=0.0, - request_sample_logprobs=1, - request_prompt_logprobs=1, + logprobs=1, + prompt_logprobs=1, max_tokens=128, stop_token_ids=[32003]), LoRARequest("sql-lora", 1, lora_path)), ( "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501 SamplingParams(temperature=0.0, - request_sample_logprobs=1, - request_prompt_logprobs=1, + logprobs=1, + prompt_logprobs=1, max_tokens=128, stop_token_ids=[32003]), LoRARequest("sql-lora2", 2, lora_path)), diff --git a/tests/conftest.py b/tests/conftest.py index 61015117a9654..d6be8f5b00af8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -794,7 +794,7 @@ def generate_w_logprobs( self._final_steps_generate_w_logprobs(req_outputs)) # Omit prompt logprobs if not required by sampling params return ([x[0:-1] for x in toks_str_logsprobs_prompt_logprobs] - if sampling_params.request_prompt_logprobs is None else + if sampling_params.prompt_logprobs is None else toks_str_logsprobs_prompt_logprobs) def generate_encoder_decoder_w_logprobs( @@ -807,14 +807,14 @@ def generate_encoder_decoder_w_logprobs( Logprobs generation for vLLM encoder/decoder models ''' - assert sampling_params.request_sample_logprobs is not None + assert sampling_params.logprobs is not None req_outputs = self.model.generate(encoder_decoder_prompts, sampling_params=sampling_params) toks_str_logsprobs_prompt_logprobs = ( self._final_steps_generate_w_logprobs(req_outputs)) # Omit prompt logprobs if not required by sampling params return ([x[0:-1] for x in toks_str_logsprobs_prompt_logprobs] - if sampling_params.request_prompt_logprobs is None else + if sampling_params.prompt_logprobs is None else toks_str_logsprobs_prompt_logprobs) def generate_greedy( @@ -850,8 +850,8 @@ def generate_greedy_logprobs( greedy_logprobs_params = SamplingParams( 
temperature=0.0, max_tokens=max_tokens, - request_sample_logprobs=num_logprobs, - request_prompt_logprobs=num_prompt_logprobs, + logprobs=num_logprobs, + prompt_logprobs=num_prompt_logprobs, stop_token_ids=stop_token_ids, stop=stop) @@ -872,8 +872,8 @@ def generate_encoder_decoder_greedy_logprobs( greedy_logprobs_params = SamplingParams( temperature=0.0, max_tokens=max_tokens, - request_sample_logprobs=num_logprobs, - request_prompt_logprobs=(num_prompt_logprobs), + logprobs=num_logprobs, + prompt_logprobs=(num_prompt_logprobs), ) ''' Greedy logprobs generation for vLLM encoder/decoder models diff --git a/tests/engine/test_skip_tokenizer_init.py b/tests/engine/test_skip_tokenizer_init.py index 09c9ed1474880..b8818af5614cf 100644 --- a/tests/engine/test_skip_tokenizer_init.py +++ b/tests/engine/test_skip_tokenizer_init.py @@ -10,8 +10,7 @@ def test_skip_tokenizer_initialization(model: str): # of tokenizer and detokenizer. The generated output is expected to contain # token ids. llm = LLM(model=model, skip_tokenizer_init=True) - sampling_params = SamplingParams(request_prompt_logprobs=True, - detokenize=True) + sampling_params = SamplingParams(prompt_logprobs=True, detokenize=True) with pytest.raises(ValueError, match="cannot pass text prompts when"): llm.generate("abc", sampling_params) diff --git a/tests/models/decoder_only/language/test_mistral.py b/tests/models/decoder_only/language/test_mistral.py index 68b95fb800bcb..99b5d5694f9f7 100644 --- a/tests/models/decoder_only/language/test_mistral.py +++ b/tests/models/decoder_only/language/test_mistral.py @@ -24,9 +24,7 @@ # "mistralai/Mistral-Nemo-Instruct-2407" ] -SAMPLING_PARAMS = SamplingParams(max_tokens=512, - temperature=0.0, - request_sample_logprobs=5) +SAMPLING_PARAMS = SamplingParams(max_tokens=512, temperature=0.0, logprobs=5) SYMBOLIC_LANG_PROMPTS = [ "勇敢な船乗りについての詩を書く", # japanese "寫一首關於勇敢的水手的詩", # chinese diff --git a/tests/models/decoder_only/vision_language/test_pixtral.py 
b/tests/models/decoder_only/vision_language/test_pixtral.py index 492cafa8a18a7..90c0fab99054c 100644 --- a/tests/models/decoder_only/vision_language/test_pixtral.py +++ b/tests/models/decoder_only/vision_language/test_pixtral.py @@ -116,9 +116,7 @@ def _create_engine_inputs_hf(urls: List[str]) -> TextPrompt: _create_engine_inputs(IMG_URLS), ] -SAMPLING_PARAMS = SamplingParams(max_tokens=512, - temperature=0.0, - request_sample_logprobs=5) +SAMPLING_PARAMS = SamplingParams(max_tokens=512, temperature=0.0, logprobs=5) LIMIT_MM_PER_PROMPT = dict(image=4) MAX_MODEL_LEN = [8192, 65536] diff --git a/tests/samplers/test_logits_processor.py b/tests/samplers/test_logits_processor.py index 646ef56f23a7b..2979470120710 100644 --- a/tests/samplers/test_logits_processor.py +++ b/tests/samplers/test_logits_processor.py @@ -29,7 +29,7 @@ def pick_vllm(token_ids, logits): params_with_logprobs = SamplingParams( logits_processors=[pick_vllm], - request_prompt_logprobs=3, + prompt_logprobs=3, max_tokens=max_tokens, ) @@ -43,7 +43,7 @@ def pick_vllm(token_ids, logits): vllm_model.model._add_request( example_prompts[1], params=SamplingParams( - request_prompt_logprobs=3, + prompt_logprobs=3, max_tokens=max_tokens, ), ) diff --git a/tests/samplers/test_logprobs.py b/tests/samplers/test_logprobs.py index dcd75c7539fe2..c07c71e38233f 100644 --- a/tests/samplers/test_logprobs.py +++ b/tests/samplers/test_logprobs.py @@ -49,12 +49,11 @@ def test_get_prompt_logprobs( max_num_batched_tokens=max_num_batched_tokens, max_num_seqs=max_num_seqs, ) as vllm_model: - vllm_sampling_params = SamplingParams( - max_tokens=max_tokens, - request_sample_logprobs=num_top_logprobs, - request_prompt_logprobs=num_top_logprobs, - temperature=0.0, - detokenize=detokenize) + vllm_sampling_params = SamplingParams(max_tokens=max_tokens, + logprobs=num_top_logprobs, + prompt_logprobs=num_top_logprobs, + temperature=0.0, + detokenize=detokenize) vllm_results = vllm_model.model.generate( example_prompts, 
sampling_params=vllm_sampling_params) @@ -132,11 +131,11 @@ def test_get_prompt_logprobs( def test_max_logprobs(): runner = VllmRunner("facebook/opt-125m", max_logprobs=1) - vllm_sampling_params = SamplingParams(request_sample_logprobs=1) + vllm_sampling_params = SamplingParams(logprobs=1) # should pass runner.generate(["Hello world"], sampling_params=vllm_sampling_params) - bad_sampling_params = SamplingParams(request_sample_logprobs=2) + bad_sampling_params = SamplingParams(logprobs=2) with pytest.raises(ValueError): runner.generate(["Hello world"], sampling_params=bad_sampling_params) @@ -161,11 +160,10 @@ def test_none_logprobs(vllm_runner, model, chunked_prefill_token_size: int, max_num_batched_tokens=max_num_batched_tokens, max_num_seqs=max_num_seqs, ) as vllm_model: - sampling_params_logprobs_none = SamplingParams( - max_tokens=max_tokens, - request_sample_logprobs=None, - temperature=0.0, - detokenize=detokenize) + sampling_params_logprobs_none = SamplingParams(max_tokens=max_tokens, + logprobs=None, + temperature=0.0, + detokenize=detokenize) results_logprobs_none = vllm_model.model.generate( example_prompts, sampling_params=sampling_params_logprobs_none) diff --git a/tests/samplers/test_ranks.py b/tests/samplers/test_ranks.py index ba41fc615d14a..ed2fee1ae252e 100644 --- a/tests/samplers/test_ranks.py +++ b/tests/samplers/test_ranks.py @@ -25,18 +25,17 @@ def test_ranks( temperature=0.0, top_p=1.0, max_tokens=max_tokens, - request_sample_logprobs=num_top_logprobs, - request_prompt_logprobs=num_prompt_logprobs) + logprobs=num_top_logprobs, + prompt_logprobs=num_prompt_logprobs) vllm_results = vllm_model.generate_w_logprobs(example_prompts, vllm_sampling_params) ## Test non-greedy logprobs ranks - sampling_params = SamplingParams( - temperature=1.0, - top_p=1.0, - max_tokens=max_tokens, - request_sample_logprobs=num_top_logprobs, - request_prompt_logprobs=num_prompt_logprobs) + sampling_params = SamplingParams(temperature=1.0, + top_p=1.0, + 
max_tokens=max_tokens, + logprobs=num_top_logprobs, + prompt_logprobs=num_prompt_logprobs) res = vllm_model.generate_w_logprobs(example_prompts, sampling_params) for result in vllm_results: diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py index 4c1dfb48fbe6f..28c34064f670c 100644 --- a/tests/samplers/test_sampler.py +++ b/tests/samplers/test_sampler.py @@ -176,7 +176,7 @@ def create_sampling_params(min_tokens, max_tokens=9999, # keep higher than max of min_tokens stop_token_ids=stop_token_ids, # requesting prompt_logprobs changes the structure of `logits` - request_prompt_logprobs=prompt_logprobs, + prompt_logprobs=prompt_logprobs, ) sampling_params.all_stop_token_ids.add(eos_token_id) return sampling_params @@ -395,7 +395,7 @@ def run_test_case(*, expected_penalization: List[bool], seq_lens.append(prompt_len) assert sgm.sampling_params is not None - if sgm.sampling_params.request_prompt_logprobs: + if sgm.sampling_params.prompt_logprobs: # with prompt_logprobs each token in the prompt has a row in # logits num_rows = prompt_len diff --git a/tests/spec_decode/e2e/conftest.py b/tests/spec_decode/e2e/conftest.py index 39a9dab2b9f11..b9cb3858c0068 100644 --- a/tests/spec_decode/e2e/conftest.py +++ b/tests/spec_decode/e2e/conftest.py @@ -196,8 +196,8 @@ def run_equality_correctness_test( max_tokens=max_output_len, seed=seed, ignore_eos=ignore_eos, - request_sample_logprobs=logprobs, - request_prompt_logprobs=prompt_logprobs) + logprobs=logprobs, + prompt_logprobs=prompt_logprobs) with vllm_runner(**org_args) as vllm_model: org_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params) diff --git a/tests/spec_decode/e2e/test_logprobs.py b/tests/spec_decode/e2e/test_logprobs.py index 7d0d90615bac2..4cfca8b78e79b 100644 --- a/tests/spec_decode/e2e/test_logprobs.py +++ b/tests/spec_decode/e2e/test_logprobs.py @@ -211,7 +211,7 @@ def test_logprobs_temp_1(vllm_runner, common_llm_kwargs, max_tokens=output_len, ignore_eos=True, 
temperature=temperature, - request_sample_logprobs=logprobs, + logprobs=logprobs, ) sd_args = { diff --git a/tests/tokenization/test_detokenize.py b/tests/tokenization/test_detokenize.py index 2fce280b188bb..84348cbc0bced 100644 --- a/tests/tokenization/test_detokenize.py +++ b/tests/tokenization/test_detokenize.py @@ -201,7 +201,7 @@ def test_decode_sequence_logprobs(complete_sequence: str, skip_special_tokens: bool): """Verify Detokenizer decodes logprobs correctly.""" sampling_params = SamplingParams(skip_special_tokens=skip_special_tokens, - request_sample_logprobs=2) + logprobs=2) # Run sequentially. seq = create_sequence() @@ -234,7 +234,7 @@ def test_decode_prompt_logprobs(complete_sequence_token_ids: List[int], detokenizer: Detokenizer): """Verify Detokenizer decodes prompt logprobs correctly.""" sampling_params = SamplingParams(skip_special_tokens=True, - request_prompt_logprobs=1) + prompt_logprobs=1) # Run sequentially. seq = create_sequence(complete_sequence_token_ids) @@ -294,8 +294,8 @@ def test_decode_prompt_logprobs_chunked_prefill( max_num_seqs=max_num_seqs) as vllm_model: vllm_sampling_params = SamplingParams(max_tokens=10, - request_sample_logprobs=5, - request_prompt_logprobs=5, + logprobs=5, + prompt_logprobs=5, temperature=0.0) vllm_results = vllm_model.model.generate( example_prompts, sampling_params=vllm_sampling_params) diff --git a/tests/v1/sample/test_logprobs.py b/tests/v1/sample/test_logprobs.py index 0d8031f05e8d1..68c72c63786ec 100644 --- a/tests/v1/sample/test_logprobs.py +++ b/tests/v1/sample/test_logprobs.py @@ -62,8 +62,8 @@ def _test_case_get_logprobs_and_prompt_logprobs( # Generate SamplingParams vllm_sampling_params = [ SamplingParams(max_tokens=max_tokens, - request_sample_logprobs=lp, - request_prompt_logprobs=plp, + logprobs=lp, + prompt_logprobs=plp, temperature=0.0, detokenize=detokenize) for lp, plp in logprob_prompt_logprob_list @@ -288,11 +288,11 @@ def test_max_logprobs(monkeypatch): 
override_backend_env_variable(monkeypatch, "FLASH_ATTN") runner = VllmRunner("facebook/opt-125m", max_logprobs=1) - vllm_sampling_params = SamplingParams(request_sample_logprobs=1) + vllm_sampling_params = SamplingParams(logprobs=1) # should pass runner.generate(["Hello world"], sampling_params=vllm_sampling_params) - bad_sampling_params = SamplingParams(request_sample_logprobs=2) + bad_sampling_params = SamplingParams(logprobs=2) with pytest.raises(ValueError): runner.generate(["Hello world"], sampling_params=bad_sampling_params) @@ -319,11 +319,10 @@ def test_none_logprobs(vllm_runner, model, example_prompts, monkeypatch): max_num_batched_tokens=max_num_batched_tokens, max_num_seqs=max_num_seqs, ) as vllm_model: - sampling_params_logprobs_none = SamplingParams( - max_tokens=max_tokens, - request_sample_logprobs=None, - request_prompt_logprobs=None, - temperature=0.0) + sampling_params_logprobs_none = SamplingParams(max_tokens=max_tokens, + logprobs=None, + prompt_logprobs=None, + temperature=0.0) results_logprobs_none = vllm_model.model.generate( example_prompts, sampling_params=sampling_params_logprobs_none) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 8286e9ce9c70d..560f84a008291 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -847,10 +847,10 @@ def _create_sequence_group_with_sampling( ) -> SequenceGroup: """Creates a SequenceGroup with SamplingParams.""" max_logprobs = self.get_model_config().max_logprobs - if (sampling_params.request_sample_logprobs - and sampling_params.request_sample_logprobs > max_logprobs - ) or (sampling_params.request_prompt_logprobs - and sampling_params.request_prompt_logprobs > max_logprobs): + if (sampling_params.logprobs + and sampling_params.logprobs > max_logprobs) or ( + sampling_params.prompt_logprobs + and sampling_params.prompt_logprobs > max_logprobs): raise ValueError(f"Cannot request more than " f"{max_logprobs} logprobs.") diff --git a/vllm/engine/protocol.py 
b/vllm/engine/protocol.py index dac592f9f373d..4079de7d36793 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -95,7 +95,7 @@ async def beam_search( tokenizer.eos_token_id, length_penalty) beam_search_params = SamplingParams( - request_sample_logprobs=2 * beam_width, + logprobs=2 * beam_width, max_tokens=1, temperature=temperature, ) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index b2a13143cdb4d..8de30ccd18a11 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -461,8 +461,7 @@ def sort_beams_key(x: BeamSearchSequence) -> float: # generate 2 * beam_width candidates at each step # following the huggingface transformers implementation # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa - beam_search_params = SamplingParams(request_sample_logprobs=2 * - beam_width, + beam_search_params = SamplingParams(logprobs=2 * beam_width, max_tokens=1, temperature=temperature) instances: List[BeamSearchInstance] = [] diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 89156850900f7..c10efefea5471 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -967,9 +967,9 @@ def get_logprobs( # Update indices and tokens for prompt logprobs. 
if (seq_group.is_prompt - and sampling_params.request_prompt_logprobs is not None): + and sampling_params.prompt_logprobs is not None): largest_num_logprobs = max(largest_num_logprobs, - sampling_params.request_prompt_logprobs) + sampling_params.prompt_logprobs) next_prompt_tokens = _get_next_prompt_tokens(seq_group) query_indices.extend(seq_group.prompt_logprob_indices) next_token_ids.extend(next_prompt_tokens) @@ -986,10 +986,9 @@ def get_logprobs( [query_idx + parent_id for parent_id in parent_seq_ids]) next_token_ids.extend(token_ids) - if sampling_params.request_sample_logprobs is not None: - largest_num_logprobs = max( - largest_num_logprobs, - sampling_params.request_sample_logprobs) + if sampling_params.logprobs is not None: + largest_num_logprobs = max(largest_num_logprobs, + sampling_params.logprobs) assert len(next_token_ids) == len(query_indices) @@ -1071,9 +1070,9 @@ def _get_prompt_logprob_if_needed( # Find prompt logprobs prompt_logprobs: Optional[PromptLogprobs] = None - if is_prompt and sampling_params.request_prompt_logprobs is not None: + if is_prompt and sampling_params.prompt_logprobs is not None: prompt_logprobs = [] - num_logprobs = sampling_params.request_prompt_logprobs + num_logprobs = sampling_params.prompt_logprobs next_prompt_tokens = _get_next_prompt_tokens(seq_group) # Pre-select indexes and create a list. It is faster than calling .item # repetitively. 
@@ -1128,7 +1127,7 @@ def _get_sampled_logprob_if_needed( ): """Compute the sample logprob if needed.""" seq_ids = seq_group.seq_ids - num_logprobs = seq_group.sampling_params.request_sample_logprobs + num_logprobs = seq_group.sampling_params.logprobs sampled_logprobs: SampleLogprobs = [] next_token_ids, parent_seq_ids = sample_result diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index 579319ffdf2ed..a58589bb915ed 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -52,7 +52,7 @@ def do_sample(self): def __post_init__(self): if len(self.prompt_logprob_indices) > 0: - assert self.sampling_params.request_prompt_logprobs is not None + assert self.sampling_params.prompt_logprobs is not None if self.is_prompt: assert self.seq_len is not None assert self.query_len is not None @@ -300,7 +300,7 @@ def _prepare_seq_groups( logits = hidden_states[selected_token_indices] """ - if sampling_params.request_prompt_logprobs is not None: + if sampling_params.prompt_logprobs is not None: selected_token_indices.extend( range(model_output_idx, model_output_idx + prompt_logprob_len)) model_output_idx += prompt_logprob_len @@ -322,7 +322,7 @@ def sample(logits): # sample_indices to find sample indices. 
""" - if sampling_params.request_prompt_logprobs is not None: + if sampling_params.prompt_logprobs is not None: prompt_logprob_indices.extend( range(logit_idx, logit_idx + prompt_logprob_len)) logit_idx += prompt_logprob_len @@ -426,8 +426,7 @@ def from_sampling_metadata( do_penalties = True is_prompt = seq_group.is_prompt - if (is_prompt - and sampling_params.request_prompt_logprobs is not None): + if (is_prompt and sampling_params.prompt_logprobs is not None): # For tokens in the prompt that we only need to get # their logprobs query_len = seq_group.query_len @@ -456,8 +455,8 @@ def from_sampling_metadata( for seq_group in sampling_metadata.seq_groups: seq_ids = seq_group.seq_ids sampling_params = seq_group.sampling_params - if (seq_group.is_prompt and - sampling_params.request_prompt_logprobs is not None): + if (seq_group.is_prompt + and sampling_params.prompt_logprobs is not None): prefill_len = len(seq_group.prompt_logprob_indices) prompt_tokens.extend( array(VLLM_TOKEN_ID_ARRAY_TYPE) diff --git a/vllm/outputs.py b/vllm/outputs.py index c6d0a31cbd8d8..c412d5ce21571 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -205,7 +205,7 @@ def from_seq_group( # NOTE: We need omit logprobs here explicitly because the sequence # always has the logprobs of the sampled tokens even if the # logprobs are not requested. 
- include_logprobs = sampling_params.request_sample_logprobs is not None + include_logprobs = sampling_params.logprobs is not None text_buffer_length = sampling_params.output_text_buffer_length delta = sampling_params.output_kind == RequestOutputKind.DELTA diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index cc4d16b3dc6ce..55664c6cf787a 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -186,8 +186,8 @@ class SamplingParams( min_tokens: int = 0 # Number of sample logprobs and prompt logprobs, # respectively, requested - request_sample_logprobs: Optional[int] = None - request_prompt_logprobs: Optional[int] = None + logprobs: Optional[int] = None + prompt_logprobs: Optional[int] = None # NOTE: This parameter is only exposed at the engine level for now. # It is not exposed in the OpenAI API server, as the OpenAI API does # not support returning only a list of token IDs. @@ -270,8 +270,8 @@ def from_optional( ignore_eos=ignore_eos, max_tokens=max_tokens, min_tokens=min_tokens, - request_sample_logprobs=logprobs, - request_prompt_logprobs=prompt_logprobs, + logprobs=logprobs, + prompt_logprobs=prompt_logprobs, detokenize=detokenize, skip_special_tokens=skip_special_tokens, spaces_between_special_tokens=spaces_between_special_tokens, @@ -328,12 +328,9 @@ def __post_init__(self) -> None: else: self.bad_words = list(self.bad_words) - self.request_sample_logprobs = (1 - if self.request_sample_logprobs is True - else self.request_sample_logprobs) - self.request_prompt_logprobs = (1 - if self.request_prompt_logprobs is True - else self.request_prompt_logprobs) + self.logprobs = (1 if self.logprobs is True else self.logprobs) + self.prompt_logprobs = (1 if self.prompt_logprobs is True else + self.prompt_logprobs) # Number of characters to hold back for stop string evaluation # until sequence is finished. 
@@ -390,14 +387,12 @@ def _verify_args(self) -> None: raise ValueError( f"min_tokens must be less than or equal to " f"max_tokens={self.max_tokens}, got {self.min_tokens}.") - if (self.request_sample_logprobs is not None - and self.request_sample_logprobs < 0): + if (self.logprobs is not None and self.logprobs < 0): raise ValueError(f"logprobs must be non-negative, " - f"got {self.request_sample_logprobs}.") - if (self.request_prompt_logprobs is not None - and self.request_prompt_logprobs < 0): + f"got {self.logprobs}.") + if (self.prompt_logprobs is not None and self.prompt_logprobs < 0): raise ValueError(f"prompt_logprobs must be non-negative, got " - f"{self.request_prompt_logprobs}.") + f"{self.prompt_logprobs}.") if (self.truncate_prompt_tokens is not None and self.truncate_prompt_tokens < 1): raise ValueError(f"truncate_prompt_tokens must be >= 1, " @@ -488,8 +483,8 @@ def __repr__(self) -> str: f"ignore_eos={self.ignore_eos}, " f"max_tokens={self.max_tokens}, " f"min_tokens={self.min_tokens}, " - f"logprobs={self.request_sample_logprobs}, " - f"prompt_logprobs={self.request_prompt_logprobs}, " + f"logprobs={self.logprobs}, " + f"prompt_logprobs={self.prompt_logprobs}, " f"skip_special_tokens={self.skip_special_tokens}, " "spaces_between_special_tokens=" f"{self.spaces_between_special_tokens}, " diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index f76b1bbd7aa07..2689802161987 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -543,9 +543,8 @@ def _serialize_sampler_output_no_logprobs( populated. 
""" seq_output_prompt_logprobs = [ - seq.is_prompt - and seq.sampling_params.request_prompt_logprobs is not None - and seq.sampling_params.request_prompt_logprobs > 0 + seq.is_prompt and seq.sampling_params.prompt_logprobs is not None + and seq.sampling_params.prompt_logprobs > 0 for seq in execute_model_req.seq_group_metadata_list ] # ignore slots for prompt tokens that are filled with INVALID_TOKEN_ID diff --git a/vllm/spec_decode/util.py b/vllm/spec_decode/util.py index 1ecc653521ad9..0b6003673578e 100644 --- a/vllm/spec_decode/util.py +++ b/vllm/spec_decode/util.py @@ -23,8 +23,7 @@ def get_all_num_logprobs( all_num_logprobs: List[int] = [] for seq_group_metadata in seq_group_metadata_list: - num_logprobs = ( - seq_group_metadata.sampling_params.request_sample_logprobs) + num_logprobs = (seq_group_metadata.sampling_params.logprobs) if num_logprobs is None: num_logprobs = 0 all_num_logprobs.append(num_logprobs) diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 535874a1fd6de..3f6fc33d5cae0 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -59,10 +59,9 @@ def _assert_valid_sample_logprobs_prompt_logprobs( """ if isinstance(params, SamplingParams) and ( - (params.request_sample_logprobs - and params.request_sample_logprobs > max_logprobs) or - (params.request_prompt_logprobs - and params.request_prompt_logprobs > max_logprobs)): + (params.logprobs and params.logprobs > max_logprobs) or + (params.prompt_logprobs + and params.prompt_logprobs > max_logprobs)): raise ValueError(f"Cannot request more than " f"{max_logprobs} logprobs or prompt logprobs.") @@ -167,8 +166,8 @@ def process_inputs( sampling_params.output_kind, sampling_params.stop, sampling_params.include_stop_str_in_output, - sampling_params.request_sample_logprobs, - sampling_params.request_prompt_logprobs, + sampling_params.logprobs, + sampling_params.prompt_logprobs, ) # Make Request for EngineCore. 
diff --git a/vllm/v1/request.py b/vllm/v1/request.py index 7fd37f2effe0c..bf789c5a01f66 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -48,8 +48,8 @@ def __init__( self._all_token_ids: List[int] = self.prompt_token_ids.copy() # Number of sample logprobs and prompt logprobs requested, # respectively - self.request_sample_logprobs = sampling_params.request_sample_logprobs - self.request_prompt_logprobs = sampling_params.request_prompt_logprobs + self.request_sample_logprobs = sampling_params.logprobs + self.request_prompt_logprobs = sampling_params.prompt_logprobs # If sample logprobs are enabled, the number of sample logprobs cannot # be anticipated in advance (because the LLM is partially responsible # for deciding when the completion is finished.) So, diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 1d59d798896f6..d88350e8303a9 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -150,13 +150,13 @@ def add_request( self.generators[req_index] = request.generator - num_logprobs = sampling_params.request_sample_logprobs - num_prompt_logprobs = sampling_params.request_prompt_logprobs + num_logprobs = sampling_params.logprobs + num_prompt_logprobs = sampling_params.prompt_logprobs if num_logprobs is not None and num_logprobs > 0: self.num_logprobs[req_id] = num_logprobs if num_prompt_logprobs is not None and num_prompt_logprobs > 0: self.num_prompt_logprobs[req_id] = num_prompt_logprobs - if sampling_params.request_prompt_logprobs: + if sampling_params.prompt_logprobs: self.prompt_logprob_reqs.add(req_id) def remove_request(self, req_id: str) -> Optional[int]: diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index 42ed3fa39abf3..0a7699cba1f32 100644 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -846,8 +846,8 @@ def _prepare_prompt( lora_index_mapping += [lora_id] * (max_prompt_len - context_len) 
lora_prompt_mapping.extend( [lora_id] * - (max_prompt_len - context_len if seq_group_metadata. - sampling_params.request_prompt_logprobs else 1)) + (max_prompt_len - context_len + if seq_group_metadata.sampling_params.prompt_logprobs else 1)) input_tokens = make_tensor_with_pad(input_tokens, max_len=max_prompt_len, @@ -1154,8 +1154,8 @@ def prepare_input_tensors( paddings = list(itertools.accumulate(paddings)) paddings_prompt_logprobs = [] for i, seq_group_metadata in enumerate(seq_group_metadata_list): - if (seq_group_metadata.sampling_params.request_prompt_logprobs - is not None and seq_group_metadata.is_prompt): + if (seq_group_metadata.sampling_params.prompt_logprobs is not None + and seq_group_metadata.is_prompt): paddings_prompt_logprobs += ([paddings[i]] * seq_lens[i]) paddings = torch.tensor( paddings_prompt_logprobs if paddings_prompt_logprobs else paddings, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index a27ada83d5da7..1bc5f65c7127f 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -625,8 +625,8 @@ def _compute_lora_input(self, inter_data: InterDataForSeqGroup, inter_data.lora_prompt_mapping.append( [lora_id] * (query_len if seq_group_metadata.sampling_params - and seq_group_metadata.sampling_params.request_prompt_logprobs - is not None else 1)) + and seq_group_metadata.sampling_params.prompt_logprobs is not None + else 1)) def _compute_prompt_adapter_input( self, inter_data: InterDataForSeqGroup, @@ -653,8 +653,8 @@ def _compute_prompt_adapter_input( prompt_adapter_id ] * num_tokens + [0] * (query_len - num_tokens) inter_data.prompt_adapter_prompt_mapping = [prompt_adapter_id] * ( - query_len if seq_group_metadata.sampling_params and - seq_group_metadata.sampling_params.request_prompt_logprobs else 1) + query_len if seq_group_metadata.sampling_params + and seq_group_metadata.sampling_params.prompt_logprobs else 1) def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup, 
seq_group_metadata: SequenceGroupMetadata): diff --git a/vllm/worker/multi_step_model_runner.py b/vllm/worker/multi_step_model_runner.py index 0783fed12daf8..3ca0d88a42183 100644 --- a/vllm/worker/multi_step_model_runner.py +++ b/vllm/worker/multi_step_model_runner.py @@ -775,14 +775,12 @@ def _pythonize_sampler_output( seq_groups = sampling_metadata.seq_groups prompt_logprobs_are_requested_for_prefill = any([ - sg.sampling_params.request_prompt_logprobs is not None and sg.is_prompt + sg.sampling_params.prompt_logprobs is not None and sg.is_prompt for sg in seq_groups ]) any_logprobs_are_requested = ( - prompt_logprobs_are_requested_for_prefill or any([ - sg.sampling_params.request_sample_logprobs is not None - for sg in seq_groups - ])) + prompt_logprobs_are_requested_for_prefill + or any([sg.sampling_params.logprobs is not None for sg in seq_groups])) if prompt_logprobs_are_requested_for_prefill: # CPU GPU sync, after gathering *only* sampled tokens (since diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index 742dfdfce6cd0..9a054eb8a4cf7 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -520,10 +520,10 @@ def _prepare_sample( f"Best of > {_MAX_NUM_SAMPLES} is not supported by the TPU " "backend.") n.append(sampling_params.n) - if sampling_params.request_sample_logprobs is not None: + if sampling_params.logprobs is not None: raise NotImplementedError( "logprobs is not currently supported by the TPU backend.") - if sampling_params.request_prompt_logprobs is not None: + if sampling_params.prompt_logprobs is not None: raise NotImplementedError( "prompt_logprobs is not currently supported by the TPU " "backend.") From dc63ac12513dd55952701115d53e614cf21a16a9 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 9 Dec 2024 18:24:31 +0000 Subject: [PATCH 15/19] removing some unnecessary changes' Signed-off-by: Andrew Feldman --- vllm/sampling_params.py | 12 +++++------- vllm/spec_decode/util.py | 2 +- 
vllm/v1/core/scheduler.py | 6 +++--- 3 files changed, 9 insertions(+), 11 deletions(-) diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 55664c6cf787a..fc77f3ca529b2 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -184,8 +184,6 @@ class SamplingParams( ignore_eos: bool = False max_tokens: Optional[int] = 16 min_tokens: int = 0 - # Number of sample logprobs and prompt logprobs, - # respectively, requested logprobs: Optional[int] = None prompt_logprobs: Optional[int] = None # NOTE: This parameter is only exposed at the engine level for now. @@ -328,7 +326,7 @@ def __post_init__(self) -> None: else: self.bad_words = list(self.bad_words) - self.logprobs = (1 if self.logprobs is True else self.logprobs) + self.logprobs = 1 if self.logprobs is True else self.logprobs self.prompt_logprobs = (1 if self.prompt_logprobs is True else self.prompt_logprobs) @@ -387,10 +385,10 @@ def _verify_args(self) -> None: raise ValueError( f"min_tokens must be less than or equal to " f"max_tokens={self.max_tokens}, got {self.min_tokens}.") - if (self.logprobs is not None and self.logprobs < 0): - raise ValueError(f"logprobs must be non-negative, " - f"got {self.logprobs}.") - if (self.prompt_logprobs is not None and self.prompt_logprobs < 0): + if self.logprobs is not None and self.logprobs < 0: + raise ValueError( + f"logprobs must be non-negative, got {self.logprobs}.") + if self.prompt_logprobs is not None and self.prompt_logprobs < 0: raise ValueError(f"prompt_logprobs must be non-negative, got " f"{self.prompt_logprobs}.") if (self.truncate_prompt_tokens is not None diff --git a/vllm/spec_decode/util.py b/vllm/spec_decode/util.py index 0b6003673578e..da8706658d09a 100644 --- a/vllm/spec_decode/util.py +++ b/vllm/spec_decode/util.py @@ -23,7 +23,7 @@ def get_all_num_logprobs( all_num_logprobs: List[int] = [] for seq_group_metadata in seq_group_metadata_list: - num_logprobs = (seq_group_metadata.sampling_params.logprobs) + num_logprobs = 
seq_group_metadata.sampling_params.logprobs if num_logprobs is None: num_logprobs = 0 all_num_logprobs.append(num_logprobs) diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index b71d1b3718528..ecf1d105d4d65 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -156,9 +156,9 @@ def schedule(self) -> "SchedulerOutput": ] num_scheduled_tokens[request.request_id] = num_new_tokens token_budget -= num_new_tokens + req_index += 1 has_partial_request = (request.num_computed_tokens + num_new_tokens < request.num_tokens) - req_index += 1 # Encoder-related. if encoder_inputs_to_schedule: @@ -234,8 +234,8 @@ def schedule(self) -> "SchedulerOutput": token_budget -= num_new_tokens request.status = RequestStatus.RUNNING request.num_computed_tokens = num_computed_tokens - has_partial_request = (request.num_computed_tokens + - num_new_tokens < request.num_tokens) + has_partial_request = (num_computed_tokens + num_new_tokens < + request.num_tokens) # Encoder-related. 
if encoder_inputs_to_schedule: From 4f304083c27351faca321f987c07eb7ee1612577 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 9 Dec 2024 18:27:32 +0000 Subject: [PATCH 16/19] removed fast checks Signed-off-by: Andrew Feldman --- tests/v1/sample/test_logprobs.py | 38 -------------------------------- 1 file changed, 38 deletions(-) diff --git a/tests/v1/sample/test_logprobs.py b/tests/v1/sample/test_logprobs.py index 68c72c63786ec..275f6b8335f4a 100644 --- a/tests/v1/sample/test_logprobs.py +++ b/tests/v1/sample/test_logprobs.py @@ -1,4 +1,3 @@ -import os from typing import List import pytest @@ -240,43 +239,6 @@ def test_get_logprobs_and_prompt_logprobs( monkeypatch=monkeypatch) -# LLM engine v1 -@pytest.mark.skipif(os.getenv("VLLM_V1_FAST_TESTS") != "1", - reason="vLLM v1 fast tests not enabled by " - "VLLM_V1_FAST_TESTS=\"1\" in the environment.") -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", - ["half"]) # needed for comparing logprobs with HF -@pytest.mark.parametrize("max_num_batched_tokens", [128]) -@pytest.mark.parametrize("batch_logprobs_composition", - ["NONE", "SAMPLE", "PROMPT", "SAMPLE_PROMPT"]) -def test_fast_get_logprobs_and_prompt_logprobs( - hf_runner, - vllm_runner, - model: str, - dtype: str, - batch_logprobs_composition: str, - max_num_batched_tokens: int, - example_prompts, - monkeypatch, -) -> None: - """Fast test: V1 Engine logprobs & prompt logprobs - - Faster version of `test_get_logprobs_and_prompt_logprobs` with - fewer test cases. 
- """ - _test_case_get_logprobs_and_prompt_logprobs( - hf_runner=hf_runner, - vllm_runner=vllm_runner, - model=model, - dtype=dtype, - detokenize=True, - batch_logprobs_composition=batch_logprobs_composition, - max_num_batched_tokens=max_num_batched_tokens, - example_prompts=example_prompts, - monkeypatch=monkeypatch) - - def test_max_logprobs(monkeypatch): """vLLM v1 engine should fail a request with `logprobs > max_logprobs` From 77488cb324b94a8bf5bfc5ff07a0137bf5633cc5 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 12 Dec 2024 10:53:20 +0000 Subject: [PATCH 17/19] wip test_completion --- .../v1/entrypoints/openai/test_completion.py | 781 ++++++++++++++++++ 1 file changed, 781 insertions(+) create mode 100644 tests/v1/entrypoints/openai/test_completion.py diff --git a/tests/v1/entrypoints/openai/test_completion.py b/tests/v1/entrypoints/openai/test_completion.py new file mode 100644 index 0000000000000..20255d6b33b06 --- /dev/null +++ b/tests/v1/entrypoints/openai/test_completion.py @@ -0,0 +1,781 @@ +# imports for guided decoding tests +import json +import re +import shutil +from tempfile import TemporaryDirectory +from typing import Dict, List, Optional + +import jsonschema +import openai # use the official client for correctness check +import pytest +import pytest_asyncio +# downloading lora to test lora requests +from huggingface_hub import snapshot_download +from openai import BadRequestError +from transformers import AutoTokenizer + +from vllm.transformers_utils.tokenizer import get_tokenizer + +from ...utils import RemoteOpenAIServer + +# any model with a chat template should work here +MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" +# technically these adapters use a different base model, +# but we're not testing generation quality here +LORA_NAME = "typeof/zephyr-7b-beta-lora" +PA_NAME = "swapnilbp/llama_tweet_ptune" +# if PA_NAME changes, PA_NUM_VIRTUAL_TOKENS might also +# need to change to match the prompt adapter +PA_NUM_VIRTUAL_TOKENS = 8 + + 
+@pytest.fixture(scope="module") +def zephyr_lora_files(): + return snapshot_download(repo_id=LORA_NAME) + + +@pytest.fixture(scope="module") +def zephyr_lora_added_tokens_files(zephyr_lora_files): + tmp_dir = TemporaryDirectory() + tmp_model_dir = f"{tmp_dir.name}/zephyr" + shutil.copytree(zephyr_lora_files, tmp_model_dir) + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) + # Copy tokenizer to adapter and add some unique tokens + # 32000, 32001, 32002 + added = tokenizer.add_tokens(["vllm1", "vllm2", "vllm3"], + special_tokens=True) + assert added == 3 + tokenizer.save_pretrained(tmp_model_dir) + yield tmp_model_dir + tmp_dir.cleanup() + + +@pytest.fixture(scope="module") +def zephyr_pa_files(): + return snapshot_download(repo_id=PA_NAME) + + +@pytest.fixture(scope="module") +def default_server_args(zephyr_lora_files, zephyr_lora_added_tokens_files, + zephyr_pa_files): + return [ + # use half precision for speed and memory savings in CI environment + "--dtype", + "bfloat16", + "--max-model-len", + "8192", + "--max-num-seqs", + "128", + "--enforce-eager", + # lora config + "--enable-lora", + "--lora-modules", + f"zephyr-lora={zephyr_lora_files}", + f"zephyr-lora2={zephyr_lora_added_tokens_files}", + "--max-lora-rank", + "64", + "--max-cpu-loras", + "2", + # pa config + "--enable-prompt-adapter", + "--prompt-adapters", + f"zephyr-pa={zephyr_pa_files}", + f"zephyr-pa2={zephyr_pa_files}", + "--max-prompt-adapters", + "2", + "--max-prompt-adapter-token", + "128", + ] + + +@pytest.fixture(scope="module", + params=["", "--disable-frontend-multiprocessing"]) +def server(default_server_args, request): + if request.param: + default_server_args.append(request.param) + with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server: + yield remote_server + + +@pytest_asyncio.fixture +async def client(server): + async with server.get_async_client() as async_client: + yield async_client + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + # first test base 
model, then test loras, then test prompt adapters + "model_name,num_virtual_tokens", + [(MODEL_NAME, 0), ("zephyr-lora", 0), ("zephyr-lora2", 0), + ("zephyr-pa", PA_NUM_VIRTUAL_TOKENS), + ("zephyr-pa2", PA_NUM_VIRTUAL_TOKENS)], +) +async def test_single_completion(client: openai.AsyncOpenAI, model_name: str, + num_virtual_tokens: int): + completion = await client.completions.create(model=model_name, + prompt="Hello, my name is", + max_tokens=5, + temperature=0.0) + + assert completion.id is not None + assert completion.choices is not None and len(completion.choices) == 1 + + choice = completion.choices[0] + assert len(choice.text) >= 5 + assert choice.finish_reason == "length" + assert completion.usage == openai.types.CompletionUsage( + completion_tokens=5, + prompt_tokens=6 + num_virtual_tokens, + total_tokens=11 + num_virtual_tokens) + + # test using token IDs + completion = await client.completions.create( + model=model_name, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + ) + assert len(completion.choices[0].text) >= 1 + assert completion.choices[0].prompt_logprobs is None + + +@pytest.mark.asyncio +async def test_added_lora_tokens(client: openai.AsyncOpenAI): + # test using token IDs + completion = await client.completions.create( + model="zephyr-lora2", + prompt=[0, 0, 32000, 32001, 32002], + echo=True, + max_tokens=5, + temperature=0.0, + ) + # Added tokens should appear in tokenized prompt + assert completion.choices[0].text.startswith("vllm1vllm2vllm3") + + +@pytest.mark.asyncio +async def test_added_lora_tokens_base_model(client: openai.AsyncOpenAI): + # test using token IDs + with pytest.raises(openai.BadRequestError, match="out of vocabulary"): + # Added tokens should be rejected by the base model + await client.completions.create( + model=MODEL_NAME, + prompt=[0, 0, 32000, 32001, 32002], + echo=True, + max_tokens=5, + temperature=0.0, + ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + # first test base model, then test loras, 
then test prompt adapters + "model_name", + [MODEL_NAME, "zephyr-lora", "zephyr-lora2", "zephyr-pa", "zephyr-pa2"], +) +async def test_no_logprobs(client: openai.AsyncOpenAI, model_name: str): + # test using token IDs + completion = await client.completions.create( + model=model_name, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + logprobs=None, + ) + choice = completion.choices[0] + assert choice.logprobs is None + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + # just test 1 lora and 1 pa hereafter + "model_name", + [MODEL_NAME, "zephyr-lora", "zephyr-pa"], +) +async def test_zero_logprobs(client: openai.AsyncOpenAI, model_name: str): + # test using token IDs + completion = await client.completions.create( + model=model_name, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + logprobs=0, + ) + choice = completion.choices[0] + assert choice.logprobs is not None + assert choice.logprobs.token_logprobs is not None + assert choice.logprobs.top_logprobs is not None + assert len(choice.logprobs.top_logprobs[0]) == 1 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME, "zephyr-lora", "zephyr-pa"], +) +async def test_some_logprobs(client: openai.AsyncOpenAI, model_name: str): + # test using token IDs + completion = await client.completions.create( + model=model_name, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + logprobs=5, + ) + choice = completion.choices[0] + assert choice.logprobs is not None + assert choice.logprobs.token_logprobs is not None + assert choice.logprobs.top_logprobs is not None + assert 5 <= len(choice.logprobs.top_logprobs[0]) <= 6 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME, "zephyr-lora", "zephyr-pa"], +) +async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI, + model_name: str): + + with pytest.raises( + (openai.BadRequestError, openai.APIError)): # test using token IDs + await client.completions.create( + 
model=model_name, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + # vLLM has higher default max_logprobs (20 instead of 5) to support + # both Completion API and Chat Completion API + logprobs=21, + ) + ... + with pytest.raises( + (openai.BadRequestError, openai.APIError)): # test using token IDs + stream = await client.completions.create( + model=model_name, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + # vLLM has higher default max_logprobs (20 instead of 5) to support + # both Completion API and Chat Completion API + logprobs=30, + stream=True, + ) + async for chunk in stream: + ... + + # the server should still work afterwards + completion = await client.completions.create( + model=model_name, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + ) + assert len(completion.choices[0].text) >= 0 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name, prompt_logprobs", [(MODEL_NAME, -1), + (MODEL_NAME, 0), + (MODEL_NAME, 1), + (MODEL_NAME, None)]) +async def test_prompt_logprobs_completion(client: openai.AsyncOpenAI, + model_name: str, + prompt_logprobs: Optional[int]): + params: Dict = { + "prompt": ["A robot may not injure another robot", "My name is"], + "model": model_name, + } + if prompt_logprobs is not None: + params["extra_body"] = {"prompt_logprobs": prompt_logprobs} + + if prompt_logprobs is not None and prompt_logprobs < 0: + with pytest.raises(BadRequestError): + await client.completions.create(**params) + else: + completion = await client.completions.create(**params) + if prompt_logprobs is not None: + assert completion.choices[0].prompt_logprobs is not None + assert len(completion.choices[0].prompt_logprobs) > 0 + + assert completion.choices[1].prompt_logprobs is not None + assert len(completion.choices[1].prompt_logprobs) > 0 + + else: + assert completion.choices[0].prompt_logprobs is None + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME, "zephyr-lora", "zephyr-pa"], 
+) +async def test_completion_streaming(client: openai.AsyncOpenAI, + model_name: str): + prompt = "What is an LLM?" + + single_completion = await client.completions.create( + model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + ) + single_output = single_completion.choices[0].text + stream = await client.completions.create(model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=True) + chunks: List[str] = [] + finish_reason_count = 0 + async for chunk in stream: + chunks.append(chunk.choices[0].text) + if chunk.choices[0].finish_reason is not None: + finish_reason_count += 1 + # finish reason should only return in last block + assert finish_reason_count == 1 + assert chunk.choices[0].finish_reason == "length" + assert chunk.choices[0].text + assert "".join(chunks) == single_output + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME, "zephyr-lora", "zephyr-pa"], +) +async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str): + """Streaming for parallel sampling. + The tokens from multiple samples, are flattened into a single stream, + with an index to indicate which sample the token belongs to. + """ + + prompt = "What is an LLM?" 
+ n = 3 + max_tokens = 5 + + stream = await client.completions.create(model=model_name, + prompt=prompt, + max_tokens=max_tokens, + n=n, + stream=True) + chunks: List[List[str]] = [[] for i in range(n)] + finish_reason_count = 0 + async for chunk in stream: + index = chunk.choices[0].index + text = chunk.choices[0].text + chunks[index].append(text) + if chunk.choices[0].finish_reason is not None: + finish_reason_count += 1 + assert finish_reason_count == n + for chunk in chunks: + assert len(chunk) == max_tokens + print("".join(chunk)) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME, "zephyr-lora", "zephyr-pa"], +) +async def test_completion_stream_options(client: openai.AsyncOpenAI, + model_name: str): + prompt = "What is the capital of France?" + + # Test stream=True, stream_options= + # {"include_usage": False, "continuous_usage_stats": False} + stream = await client.completions.create(model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=True, + stream_options={ + "include_usage": False, + "continuous_usage_stats": + False, + }) + + async for chunk in stream: + assert chunk.usage is None + + # Test stream=True, stream_options= + # {"include_usage": False, "continuous_usage_stats": True} + stream = await client.completions.create(model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=True, + stream_options={ + "include_usage": False, + "continuous_usage_stats": + True, + }) + async for chunk in stream: + assert chunk.usage is None + + # Test stream=True, stream_options= + # {"include_usage": True, "continuous_usage_stats": False} + stream = await client.completions.create(model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=True, + stream_options={ + "include_usage": True, + "continuous_usage_stats": + False, + }) + async for chunk in stream: + if chunk.choices[0].finish_reason is None: + assert chunk.usage is None + else: + assert chunk.usage is None + 
final_chunk = await stream.__anext__() + assert final_chunk.usage is not None + assert final_chunk.usage.prompt_tokens > 0 + assert final_chunk.usage.completion_tokens > 0 + assert final_chunk.usage.total_tokens == ( + final_chunk.usage.prompt_tokens + + final_chunk.usage.completion_tokens) + assert final_chunk.choices == [] + + # Test stream=True, stream_options= + # {"include_usage": True, "continuous_usage_stats": True} + stream = await client.completions.create(model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=True, + stream_options={ + "include_usage": True, + "continuous_usage_stats": + True, + }) + async for chunk in stream: + assert chunk.usage is not None + assert chunk.usage.prompt_tokens > 0 + assert chunk.usage.completion_tokens > 0 + assert chunk.usage.total_tokens == (chunk.usage.prompt_tokens + + chunk.usage.completion_tokens) + if chunk.choices[0].finish_reason is not None: + final_chunk = await stream.__anext__() + assert final_chunk.usage is not None + assert final_chunk.usage.prompt_tokens > 0 + assert final_chunk.usage.completion_tokens > 0 + assert final_chunk.usage.total_tokens == ( + final_chunk.usage.prompt_tokens + + final_chunk.usage.completion_tokens) + assert final_chunk.choices == [] + + # Test stream=False, stream_options= + # {"include_usage": None} + with pytest.raises(BadRequestError): + await client.completions.create(model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=False, + stream_options={"include_usage": None}) + + # Test stream=False, stream_options= + # {"include_usage": True} + with pytest.raises(BadRequestError): + await client.completions.create(model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=False, + stream_options={"include_usage": True}) + + # Test stream=False, stream_options= + # {"continuous_usage_stats": None} + with pytest.raises(BadRequestError): + await client.completions.create( + model=model_name, + prompt=prompt, + 
max_tokens=5, + temperature=0.0, + stream=False, + stream_options={"continuous_usage_stats": None}) + + # Test stream=False, stream_options= + # {"continuous_usage_stats": True} + with pytest.raises(BadRequestError): + await client.completions.create( + model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=False, + stream_options={"continuous_usage_stats": True}) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME, "zephyr-lora", "zephyr-pa"], +) +async def test_batch_completions(client: openai.AsyncOpenAI, model_name: str): + # test both text and token IDs + for prompts in (["Hello, my name is"] * 2, [[0, 0, 0, 0, 0]] * 2): + # test simple list + batch = await client.completions.create( + model=model_name, + prompt=prompts, + max_tokens=5, + temperature=0.0, + ) + assert len(batch.choices) == 2 + assert batch.choices[0].text == batch.choices[1].text + + # test n = 2 + batch = await client.completions.create( + model=model_name, + prompt=prompts, + n=2, + max_tokens=5, + temperature=0.0, + extra_body=dict( + # NOTE: this has to be true for n > 1 in vLLM, but + # not necessary for official client. 
+ use_beam_search=True), + ) + assert len(batch.choices) == 4 + assert batch.choices[0].text != batch.choices[ + 1].text, "beam search should be different" + assert batch.choices[0].text == batch.choices[ + 2].text, "two copies of the same prompt should be the same" + assert batch.choices[1].text == batch.choices[ + 3].text, "two copies of the same prompt should be the same" + + # test streaming + batch = await client.completions.create( + model=model_name, + prompt=prompts, + max_tokens=5, + temperature=0.0, + stream=True, + ) + texts = [""] * 2 + async for chunk in batch: + assert len(chunk.choices) == 1 + choice = chunk.choices[0] + texts[choice.index] += choice.text + assert texts[0] == texts[1] + + +@pytest.mark.asyncio +async def test_logits_bias(client: openai.AsyncOpenAI): + prompt = "Hello, my name is" + max_tokens = 5 + tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME) + + # Test exclusive selection + token_id = 1000 + completion = await client.completions.create( + model=MODEL_NAME, + prompt=prompt, + max_tokens=max_tokens, + temperature=0.0, + logit_bias={str(token_id): 100}, + seed=42, + ) + assert len(completion.choices[0].text) >= 5 + response_tokens = tokenizer(completion.choices[0].text, + add_special_tokens=False)["input_ids"] + expected_tokens = tokenizer(tokenizer.decode([token_id] * 5), + add_special_tokens=False)["input_ids"] + assert all([ + response == expected + for response, expected in zip(response_tokens, expected_tokens) + ]) + + # Test ban + completion = await client.completions.create( + model=MODEL_NAME, + prompt=prompt, + max_tokens=max_tokens, + temperature=0.0, + ) + response_tokens = tokenizer(completion.choices[0].text, + add_special_tokens=False)["input_ids"] + first_response = completion.choices[0].text + completion = await client.completions.create( + model=MODEL_NAME, + prompt=prompt, + max_tokens=max_tokens, + temperature=0.0, + logit_bias={str(token): -100 + for token in response_tokens}, + ) + assert first_response != 
completion.choices[0].text + + +@pytest.mark.asyncio +async def test_allowed_token_ids(client: openai.AsyncOpenAI): + prompt = "Hello, my name is" + max_tokens = 1 + tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME) + + # Test exclusive selection + allowed_ids = [21555, 21557, 21558] + completion = await client.completions.create( + model=MODEL_NAME, + prompt=prompt, + max_tokens=max_tokens, + temperature=0.0, + seed=42, + extra_body=dict(allowed_token_ids=allowed_ids), + logprobs=1, + ) + response_tokens = completion.choices[0].logprobs.tokens + assert len(response_tokens) == 1 + assert tokenizer.convert_tokens_to_ids(response_tokens)[0] in allowed_ids + + +@pytest.mark.asyncio +@pytest.mark.parametrize("guided_decoding_backend", + ["outlines", "lm-format-enforcer"]) +async def test_guided_json_completion(client: openai.AsyncOpenAI, + guided_decoding_backend: str, + sample_json_schema): + completion = await client.completions.create( + model=MODEL_NAME, + prompt=f"Give an example JSON for an employee profile " + f"that fits this schema: {sample_json_schema}", + n=3, + temperature=1.0, + max_tokens=500, + extra_body=dict(guided_json=sample_json_schema, + guided_decoding_backend=guided_decoding_backend)) + + assert completion.id is not None + assert len(completion.choices) == 3 + for i in range(3): + output_json = json.loads(completion.choices[i].text) + jsonschema.validate(instance=output_json, schema=sample_json_schema) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("guided_decoding_backend", + ["outlines", "lm-format-enforcer"]) +async def test_guided_regex_completion(client: openai.AsyncOpenAI, + guided_decoding_backend: str, + sample_regex): + completion = await client.completions.create( + model=MODEL_NAME, + prompt=f"Give an example IPv4 address with this regex: {sample_regex}", + n=3, + temperature=1.0, + max_tokens=20, + extra_body=dict(guided_regex=sample_regex, + guided_decoding_backend=guided_decoding_backend)) + + assert completion.id is not 
None + assert len(completion.choices) == 3 + for i in range(3): + assert re.fullmatch(sample_regex, + completion.choices[i].text) is not None + + +@pytest.mark.asyncio +@pytest.mark.parametrize("guided_decoding_backend", + ["outlines", "lm-format-enforcer"]) +async def test_guided_choice_completion(client: openai.AsyncOpenAI, + guided_decoding_backend: str, + sample_guided_choice): + completion = await client.completions.create( + model=MODEL_NAME, + prompt="The best language for type-safe systems programming is ", + n=2, + temperature=1.0, + max_tokens=10, + extra_body=dict(guided_choice=sample_guided_choice, + guided_decoding_backend=guided_decoding_backend)) + + assert completion.id is not None + assert len(completion.choices) == 2 + for i in range(2): + assert completion.choices[i].text in sample_guided_choice + + +@pytest.mark.asyncio +async def test_guided_grammar(client: openai.AsyncOpenAI, + sample_sql_statements): + + completion = await client.completions.create( + model=MODEL_NAME, + prompt=("Generate a sql state that select col_1 from " + "table_1 where it is equals to 1"), + temperature=1.0, + max_tokens=500, + extra_body=dict(guided_grammar=sample_sql_statements)) + + content = completion.choices[0].text + + # use Lark to parse the output, and make sure it's a valid parse tree + from lark import Lark + parser = Lark(sample_sql_statements) + parser.parse(content) + + # remove spaces for comparison b/c we removed them in the grammar + ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(" ", "") + + assert content.strip() == ground_truth + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + # first test base model, then test loras + "model_name", + [MODEL_NAME, "zephyr-lora", "zephyr-lora2"], +) +@pytest.mark.parametrize("logprobs_arg", [1, 0]) +async def test_echo_logprob_completion(client: openai.AsyncOpenAI, + model_name: str, logprobs_arg: int): + tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME) + # test using text and token IDs + 
for prompt in ("Hello, my name is", [0, 0, 0, 0, 0]): + completion = await client.completions.create(model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + echo=True, + logprobs=logprobs_arg) + + prompt_text = tokenizer.decode(prompt) if isinstance(prompt, + list) else prompt + assert re.search(r"^" + prompt_text, completion.choices[0].text) + logprobs = completion.choices[0].logprobs + assert logprobs is not None + assert len(logprobs.text_offset) > 5 + assert (len(logprobs.token_logprobs) > 5 + and logprobs.token_logprobs[0] is None) + assert (len(logprobs.top_logprobs) > 5 + and logprobs.top_logprobs[0] is None) + for top_logprobs in logprobs.top_logprobs[1:]: + assert max(logprobs_arg, + 1) <= len(top_logprobs) <= logprobs_arg + 1 + assert len(logprobs.tokens) > 5 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("guided_decoding_backend", + ["outlines", "lm-format-enforcer"]) +async def test_guided_decoding_type_error(client: openai.AsyncOpenAI, + guided_decoding_backend: str, + sample_json_schema, sample_regex): + with pytest.raises(openai.BadRequestError): + _ = await client.completions.create( + model=MODEL_NAME, + prompt="Give an example JSON that fits this schema: 42", + extra_body=dict(guided_json=42, + guided_decoding_backend=guided_decoding_backend)) + + with pytest.raises(openai.BadRequestError): + _ = await client.completions.create( + model=MODEL_NAME, + prompt="Give an example string that fits this regex", + extra_body=dict(guided_regex=sample_regex, + guided_json=sample_json_schema)) \ No newline at end of file From f1a689c2d0b4a90ff96216fce5eb0cae44262fa2 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 12 Dec 2024 11:34:53 +0000 Subject: [PATCH 18/19] toward completion tests Signed-off-by: Andrew Feldman --- .../v1/entrypoints/openai/test_completion.py | 291 ++---------------- 1 file changed, 18 insertions(+), 273 deletions(-) diff --git a/tests/v1/entrypoints/openai/test_completion.py 
b/tests/v1/entrypoints/openai/test_completion.py index 20255d6b33b06..1a3d458b118ab 100644 --- a/tests/v1/entrypoints/openai/test_completion.py +++ b/tests/v1/entrypoints/openai/test_completion.py @@ -1,63 +1,21 @@ # imports for guided decoding tests -import json import re -import shutil -from tempfile import TemporaryDirectory from typing import Dict, List, Optional -import jsonschema import openai # use the official client for correctness check import pytest import pytest_asyncio -# downloading lora to test lora requests -from huggingface_hub import snapshot_download from openai import BadRequestError -from transformers import AutoTokenizer +from tests.utils import RemoteOpenAIServer from vllm.transformers_utils.tokenizer import get_tokenizer -from ...utils import RemoteOpenAIServer - # any model with a chat template should work here MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" -# technically these adapters use a different base model, -# but we're not testing generation quality here -LORA_NAME = "typeof/zephyr-7b-beta-lora" -PA_NAME = "swapnilbp/llama_tweet_ptune" -# if PA_NAME changes, PA_NUM_VIRTUAL_TOKENS might also -# need to change to match the prompt adapter -PA_NUM_VIRTUAL_TOKENS = 8 - - -@pytest.fixture(scope="module") -def zephyr_lora_files(): - return snapshot_download(repo_id=LORA_NAME) - - -@pytest.fixture(scope="module") -def zephyr_lora_added_tokens_files(zephyr_lora_files): - tmp_dir = TemporaryDirectory() - tmp_model_dir = f"{tmp_dir.name}/zephyr" - shutil.copytree(zephyr_lora_files, tmp_model_dir) - tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) - # Copy tokenizer to adapter and add some unique tokens - # 32000, 32001, 32002 - added = tokenizer.add_tokens(["vllm1", "vllm2", "vllm3"], - special_tokens=True) - assert added == 3 - tokenizer.save_pretrained(tmp_model_dir) - yield tmp_model_dir - tmp_dir.cleanup() - - -@pytest.fixture(scope="module") -def zephyr_pa_files(): - return snapshot_download(repo_id=PA_NAME) 
@pytest.fixture(scope="module") -def default_server_args(zephyr_lora_files, zephyr_lora_added_tokens_files, - zephyr_pa_files): +def default_server_args(): return [ # use half precision for speed and memory savings in CI environment "--dtype", @@ -67,24 +25,6 @@ def default_server_args(zephyr_lora_files, zephyr_lora_added_tokens_files, "--max-num-seqs", "128", "--enforce-eager", - # lora config - "--enable-lora", - "--lora-modules", - f"zephyr-lora={zephyr_lora_files}", - f"zephyr-lora2={zephyr_lora_added_tokens_files}", - "--max-lora-rank", - "64", - "--max-cpu-loras", - "2", - # pa config - "--enable-prompt-adapter", - "--prompt-adapters", - f"zephyr-pa={zephyr_pa_files}", - f"zephyr-pa2={zephyr_pa_files}", - "--max-prompt-adapters", - "2", - "--max-prompt-adapter-token", - "128", ] @@ -105,14 +45,11 @@ async def client(server): @pytest.mark.asyncio @pytest.mark.parametrize( - # first test base model, then test loras, then test prompt adapters - "model_name,num_virtual_tokens", - [(MODEL_NAME, 0), ("zephyr-lora", 0), ("zephyr-lora2", 0), - ("zephyr-pa", PA_NUM_VIRTUAL_TOKENS), - ("zephyr-pa2", PA_NUM_VIRTUAL_TOKENS)], + "model_name", + [MODEL_NAME], ) -async def test_single_completion(client: openai.AsyncOpenAI, model_name: str, - num_virtual_tokens: int): +async def test_single_completion(client: openai.AsyncOpenAI, + model_name: str) -> None: completion = await client.completions.create(model=model_name, prompt="Hello, my name is", max_tokens=5, @@ -125,9 +62,7 @@ async def test_single_completion(client: openai.AsyncOpenAI, model_name: str, assert len(choice.text) >= 5 assert choice.finish_reason == "length" assert completion.usage == openai.types.CompletionUsage( - completion_tokens=5, - prompt_tokens=6 + num_virtual_tokens, - total_tokens=11 + num_virtual_tokens) + completion_tokens=5, prompt_tokens=6, total_tokens=11) # test using token IDs completion = await client.completions.create( @@ -140,39 +75,10 @@ async def test_single_completion(client: 
openai.AsyncOpenAI, model_name: str, assert completion.choices[0].prompt_logprobs is None -@pytest.mark.asyncio -async def test_added_lora_tokens(client: openai.AsyncOpenAI): - # test using token IDs - completion = await client.completions.create( - model="zephyr-lora2", - prompt=[0, 0, 32000, 32001, 32002], - echo=True, - max_tokens=5, - temperature=0.0, - ) - # Added tokens should appear in tokenized prompt - assert completion.choices[0].text.startswith("vllm1vllm2vllm3") - - -@pytest.mark.asyncio -async def test_added_lora_tokens_base_model(client: openai.AsyncOpenAI): - # test using token IDs - with pytest.raises(openai.BadRequestError, match="out of vocabulary"): - # Added tokens should be rejected by the base model - await client.completions.create( - model=MODEL_NAME, - prompt=[0, 0, 32000, 32001, 32002], - echo=True, - max_tokens=5, - temperature=0.0, - ) - - @pytest.mark.asyncio @pytest.mark.parametrize( - # first test base model, then test loras, then test prompt adapters "model_name", - [MODEL_NAME, "zephyr-lora", "zephyr-lora2", "zephyr-pa", "zephyr-pa2"], + [MODEL_NAME], ) async def test_no_logprobs(client: openai.AsyncOpenAI, model_name: str): # test using token IDs @@ -189,9 +95,8 @@ async def test_no_logprobs(client: openai.AsyncOpenAI, model_name: str): @pytest.mark.asyncio @pytest.mark.parametrize( - # just test 1 lora and 1 pa hereafter "model_name", - [MODEL_NAME, "zephyr-lora", "zephyr-pa"], + [MODEL_NAME], ) async def test_zero_logprobs(client: openai.AsyncOpenAI, model_name: str): # test using token IDs @@ -212,7 +117,7 @@ async def test_zero_logprobs(client: openai.AsyncOpenAI, model_name: str): @pytest.mark.asyncio @pytest.mark.parametrize( "model_name", - [MODEL_NAME, "zephyr-lora", "zephyr-pa"], + [MODEL_NAME], ) async def test_some_logprobs(client: openai.AsyncOpenAI, model_name: str): # test using token IDs @@ -233,10 +138,10 @@ async def test_some_logprobs(client: openai.AsyncOpenAI, model_name: str): @pytest.mark.asyncio 
@pytest.mark.parametrize( "model_name", - [MODEL_NAME, "zephyr-lora", "zephyr-pa"], + [MODEL_NAME], ) async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI, - model_name: str): + model_name: str) -> None: with pytest.raises( (openai.BadRequestError, openai.APIError)): # test using token IDs @@ -309,10 +214,10 @@ async def test_prompt_logprobs_completion(client: openai.AsyncOpenAI, @pytest.mark.asyncio @pytest.mark.parametrize( "model_name", - [MODEL_NAME, "zephyr-lora", "zephyr-pa"], + [MODEL_NAME], ) async def test_completion_streaming(client: openai.AsyncOpenAI, - model_name: str): + model_name: str) -> None: prompt = "What is an LLM?" single_completion = await client.completions.create( @@ -343,7 +248,7 @@ async def test_completion_streaming(client: openai.AsyncOpenAI, @pytest.mark.asyncio @pytest.mark.parametrize( "model_name", - [MODEL_NAME, "zephyr-lora", "zephyr-pa"], + [MODEL_NAME], ) async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str): """Streaming for parallel sampling. 
@@ -377,7 +282,7 @@ async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str): @pytest.mark.asyncio @pytest.mark.parametrize( "model_name", - [MODEL_NAME, "zephyr-lora", "zephyr-pa"], + [MODEL_NAME], ) async def test_completion_stream_options(client: openai.AsyncOpenAI, model_name: str): @@ -514,7 +419,7 @@ async def test_completion_stream_options(client: openai.AsyncOpenAI, @pytest.mark.asyncio @pytest.mark.parametrize( "model_name", - [MODEL_NAME, "zephyr-lora", "zephyr-pa"], + [MODEL_NAME], ) async def test_batch_completions(client: openai.AsyncOpenAI, model_name: str): # test both text and token IDs @@ -565,53 +470,6 @@ async def test_batch_completions(client: openai.AsyncOpenAI, model_name: str): assert texts[0] == texts[1] -@pytest.mark.asyncio -async def test_logits_bias(client: openai.AsyncOpenAI): - prompt = "Hello, my name is" - max_tokens = 5 - tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME) - - # Test exclusive selection - token_id = 1000 - completion = await client.completions.create( - model=MODEL_NAME, - prompt=prompt, - max_tokens=max_tokens, - temperature=0.0, - logit_bias={str(token_id): 100}, - seed=42, - ) - assert len(completion.choices[0].text) >= 5 - response_tokens = tokenizer(completion.choices[0].text, - add_special_tokens=False)["input_ids"] - expected_tokens = tokenizer(tokenizer.decode([token_id] * 5), - add_special_tokens=False)["input_ids"] - assert all([ - response == expected - for response, expected in zip(response_tokens, expected_tokens) - ]) - - # Test ban - completion = await client.completions.create( - model=MODEL_NAME, - prompt=prompt, - max_tokens=max_tokens, - temperature=0.0, - ) - response_tokens = tokenizer(completion.choices[0].text, - add_special_tokens=False)["input_ids"] - first_response = completion.choices[0].text - completion = await client.completions.create( - model=MODEL_NAME, - prompt=prompt, - max_tokens=max_tokens, - temperature=0.0, - logit_bias={str(token): -100 - for token in 
response_tokens}, - ) - assert first_response != completion.choices[0].text - - @pytest.mark.asyncio async def test_allowed_token_ids(client: openai.AsyncOpenAI): prompt = "Hello, my name is" @@ -634,102 +492,10 @@ async def test_allowed_token_ids(client: openai.AsyncOpenAI): assert tokenizer.convert_tokens_to_ids(response_tokens)[0] in allowed_ids -@pytest.mark.asyncio -@pytest.mark.parametrize("guided_decoding_backend", - ["outlines", "lm-format-enforcer"]) -async def test_guided_json_completion(client: openai.AsyncOpenAI, - guided_decoding_backend: str, - sample_json_schema): - completion = await client.completions.create( - model=MODEL_NAME, - prompt=f"Give an example JSON for an employee profile " - f"that fits this schema: {sample_json_schema}", - n=3, - temperature=1.0, - max_tokens=500, - extra_body=dict(guided_json=sample_json_schema, - guided_decoding_backend=guided_decoding_backend)) - - assert completion.id is not None - assert len(completion.choices) == 3 - for i in range(3): - output_json = json.loads(completion.choices[i].text) - jsonschema.validate(instance=output_json, schema=sample_json_schema) - - -@pytest.mark.asyncio -@pytest.mark.parametrize("guided_decoding_backend", - ["outlines", "lm-format-enforcer"]) -async def test_guided_regex_completion(client: openai.AsyncOpenAI, - guided_decoding_backend: str, - sample_regex): - completion = await client.completions.create( - model=MODEL_NAME, - prompt=f"Give an example IPv4 address with this regex: {sample_regex}", - n=3, - temperature=1.0, - max_tokens=20, - extra_body=dict(guided_regex=sample_regex, - guided_decoding_backend=guided_decoding_backend)) - - assert completion.id is not None - assert len(completion.choices) == 3 - for i in range(3): - assert re.fullmatch(sample_regex, - completion.choices[i].text) is not None - - -@pytest.mark.asyncio -@pytest.mark.parametrize("guided_decoding_backend", - ["outlines", "lm-format-enforcer"]) -async def test_guided_choice_completion(client: 
openai.AsyncOpenAI, - guided_decoding_backend: str, - sample_guided_choice): - completion = await client.completions.create( - model=MODEL_NAME, - prompt="The best language for type-safe systems programming is ", - n=2, - temperature=1.0, - max_tokens=10, - extra_body=dict(guided_choice=sample_guided_choice, - guided_decoding_backend=guided_decoding_backend)) - - assert completion.id is not None - assert len(completion.choices) == 2 - for i in range(2): - assert completion.choices[i].text in sample_guided_choice - - -@pytest.mark.asyncio -async def test_guided_grammar(client: openai.AsyncOpenAI, - sample_sql_statements): - - completion = await client.completions.create( - model=MODEL_NAME, - prompt=("Generate a sql state that select col_1 from " - "table_1 where it is equals to 1"), - temperature=1.0, - max_tokens=500, - extra_body=dict(guided_grammar=sample_sql_statements)) - - content = completion.choices[0].text - - # use Lark to parse the output, and make sure it's a valid parse tree - from lark import Lark - parser = Lark(sample_sql_statements) - parser.parse(content) - - # remove spaces for comparison b/c we removed them in the grammar - ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(" ", "") - - assert content.strip() == ground_truth - - @pytest.mark.asyncio @pytest.mark.parametrize( - # first test base model, then test loras "model_name", - [MODEL_NAME, "zephyr-lora", "zephyr-lora2"], + [MODEL_NAME], ) @pytest.mark.parametrize("logprobs_arg", [1, 0]) async def test_echo_logprob_completion(client: openai.AsyncOpenAI, @@ -758,24 +524,3 @@ async def test_echo_logprob_completion(client: openai.AsyncOpenAI, assert max(logprobs_arg, 1) <= len(top_logprobs) <= logprobs_arg + 1 assert len(logprobs.tokens) > 5 - - -@pytest.mark.asyncio -@pytest.mark.parametrize("guided_decoding_backend", - ["outlines", "lm-format-enforcer"]) -async def test_guided_decoding_type_error(client: openai.AsyncOpenAI, - guided_decoding_backend: str, - 
sample_json_schema, sample_regex): - with pytest.raises(openai.BadRequestError): - _ = await client.completions.create( - model=MODEL_NAME, - prompt="Give an example JSON that fits this schema: 42", - extra_body=dict(guided_json=42, - guided_decoding_backend=guided_decoding_backend)) - - with pytest.raises(openai.BadRequestError): - _ = await client.completions.create( - model=MODEL_NAME, - prompt="Give an example string that fits this regex", - extra_body=dict(guided_regex=sample_regex, - guided_json=sample_json_schema)) \ No newline at end of file From e962aa7e4d74f4e42a5464ba82f2ac41156e803d Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 12 Dec 2024 17:51:40 +0000 Subject: [PATCH 19/19] serialization fix Signed-off-by: Andrew Feldman --- vllm/v1/engine/core.py | 4 ++-- vllm/v1/engine/core_client.py | 5 +++-- vllm/v1/serial_utils.py | 28 ++++++++++++++++++++++++++++ 3 files changed, 33 insertions(+), 4 deletions(-) diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 5fc4f2e425726..bf07dc94bb8f7 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -23,7 +23,7 @@ from vllm.v1.executor.gpu_executor import GPUExecutor from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.request import Request, RequestStatus -from vllm.v1.serial_utils import PickleEncoder +from vllm.v1.serial_utils import PickleEncoder, custom_enc_hook from vllm.version import __version__ as VLLM_VERSION logger = init_logger(__name__) @@ -517,7 +517,7 @@ def process_output_socket(self, output_path: str): """Output socket IO thread.""" # Msgpack serialization encoding. - encoder = msgpack.Encoder() + encoder = msgpack.Encoder(enc_hook=custom_enc_hook) # Reuse send buffer. 
buffer = bytearray() diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index 835963f7ee86c..236d633e8d5da 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -12,7 +12,7 @@ EngineCoreProfile, EngineCoreRequest, EngineCoreRequestType) from vllm.v1.engine.core import EngineCore, EngineCoreProc -from vllm.v1.serial_utils import PickleEncoder +from vllm.v1.serial_utils import PickleEncoder, custom_ext_hook logger = init_logger(__name__) @@ -124,7 +124,8 @@ def __init__( ): # Serialization setup. self.encoder = PickleEncoder() - self.decoder = msgspec.msgpack.Decoder(EngineCoreOutputs) + self.decoder = msgspec.msgpack.Decoder(EngineCoreOutputs, + ext_hook=custom_ext_hook) # ZMQ setup. self.ctx = (zmq.asyncio.Context() if asyncio_mode else zmq.Context()) diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py index b1cd5c11834f8..76f7076cfa9e0 100644 --- a/vllm/v1/serial_utils.py +++ b/vllm/v1/serial_utils.py @@ -1,4 +1,11 @@ import pickle +from typing import Any + +import numpy as np +from msgspec import msgpack + +CUSTOM_TYPE_CODE_PICKLE = 1 +pickle_types = (np.ndarray, ) class PickleEncoder: @@ -8,3 +15,24 @@ def encode(self, obj): def decode(self, data): return pickle.loads(data) + + +def custom_enc_hook(obj: Any) -> Any: + if isinstance(obj, pickle_types): + # Return an `Ext` object so msgspec serializes it as an extension type. + return msgpack.Ext(CUSTOM_TYPE_CODE_PICKLE, pickle.dumps(obj)) + else: + # Raise a NotImplementedError for other types + raise NotImplementedError( + f"Objects of type {type(obj)} are not supported") + + +def custom_ext_hook(code: int, data: memoryview) -> Any: + if code == CUSTOM_TYPE_CODE_PICKLE: + # This extension type carries a pickled object (see custom_enc_hook); + # unpickle the data buffer to recover it. + return pickle.loads(data) + else: + # Raise a NotImplementedError for other extension type codes + raise NotImplementedError( + f"Extension type code {code} is not supported")