From 6d917d0eebd03990edf2443780a5f2506026ea78 Mon Sep 17 00:00:00 2001 From: Mark McLoughlin Date: Sat, 14 Dec 2024 17:54:04 +0000 Subject: [PATCH] Enable mypy checking on V1 code (#11105) Signed-off-by: Mark McLoughlin --- tools/mypy.sh | 1 + vllm/v1/attention/backends/flash_attn.py | 2 ++ vllm/v1/core/kv_cache_manager.py | 10 +++--- vllm/v1/core/kv_cache_utils.py | 17 +++++----- vllm/v1/core/scheduler.py | 1 + vllm/v1/engine/__init__.py | 23 ++++++++----- vllm/v1/engine/async_llm.py | 11 +++--- vllm/v1/engine/core.py | 20 +++++------ vllm/v1/engine/core_client.py | 43 +++++++++++++----------- vllm/v1/engine/detokenizer.py | 4 +-- vllm/v1/engine/llm_engine.py | 3 +- vllm/v1/engine/mm_input_mapper.py | 20 +++++++---- vllm/v1/engine/processor.py | 2 +- vllm/v1/executor/abstract.py | 12 ++----- vllm/v1/executor/multiproc_executor.py | 15 +++++---- vllm/v1/executor/uniproc_executor.py | 7 ++-- vllm/v1/request.py | 3 +- vllm/v1/utils.py | 42 ++++++++++++++--------- vllm/v1/worker/gpu_input_batch.py | 1 + vllm/v1/worker/gpu_model_runner.py | 42 ++++++++++++++--------- vllm/v1/worker/gpu_worker.py | 2 +- 21 files changed, 160 insertions(+), 121 deletions(-) diff --git a/tools/mypy.sh b/tools/mypy.sh index e984e739d70cf..2454ff9fde466 100755 --- a/tools/mypy.sh +++ b/tools/mypy.sh @@ -29,3 +29,4 @@ run_mypy vllm/plugins run_mypy vllm/prompt_adapter run_mypy vllm/spec_decode run_mypy vllm/worker +run_mypy vllm/v1 diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index c9f04ace644c7..026a0292cc339 100644 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -135,6 +135,8 @@ def forward( assert k_scale == 1.0 and v_scale == 1.0, ( "key/v_scale is not supported in FlashAttention.") + assert output is not None, "Output tensor must be provided." + if attn_metadata is None: # Profiling run. return output diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 8044481a9cd6a..aaa44c930e324 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -1,5 +1,5 @@ from collections import defaultdict -from typing import Dict, List, Optional +from typing import Dict, Iterable, List, Optional from vllm.logger import init_logger from vllm.utils import cdiv @@ -263,12 +263,13 @@ def free(self, request: Request) -> None: """ # Default to [] in case a request is freed (aborted) before alloc. blocks = self.req_to_blocks.pop(request.request_id, []) + ordered_blocks: Iterable[KVCacheBlock] = blocks if self.enable_caching: # Free blocks in reverse order so that the tail blocks are # freed first. - blocks = reversed(blocks) + ordered_blocks = reversed(blocks) - for block in blocks: + for block in ordered_blocks: block.decr_ref() if block.ref_cnt == 0: self.free_block_queue.append(block) @@ -396,8 +397,7 @@ def _cache_full_blocks( f"{request.request_id}({request})") # Compute the hash of the current block. - block_hash = hash_block_tokens(prev_block_hash_value, - tuple(block_tokens)) + block_hash = hash_block_tokens(prev_block_hash_value, block_tokens) # Update and added the full block to the cache. blk.block_hash = block_hash diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 814e462a91fed..0ba338aa5a3d2 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -1,4 +1,5 @@ """KV-Cache Utilities.""" +from collections.abc import Sequence from dataclasses import dataclass from typing import List, NamedTuple, Optional, Tuple @@ -13,7 +14,7 @@ class BlockHashType(NamedTuple): collision happens when the hash value is the same. """ hash_value: int - token_ids: Tuple[int] + token_ids: Tuple[int, ...] @dataclass @@ -79,8 +80,8 @@ def __init__(self, blocks: List[KVCacheBlock]) -> None: self.num_free_blocks = len(blocks) # Initialize the doubly linked list of free blocks. - self.free_list_head = blocks[0] - self.free_list_tail = blocks[-1] + self.free_list_head: Optional[KVCacheBlock] = blocks[0] + self.free_list_tail: Optional[KVCacheBlock] = blocks[-1] for i in range(self.num_free_blocks): if i > 0: blocks[i].prev_free_block = blocks[i - 1] @@ -159,7 +160,7 @@ def get_all_free_blocks(self) -> List[KVCacheBlock]: def hash_block_tokens(parent_block_hash: Optional[int], - curr_block_token_ids: Tuple[int]) -> BlockHashType: + curr_block_token_ids: Sequence[int]) -> BlockHashType: """Computes a hash value corresponding to the contents of a block and the contents of the preceding block(s). The hash value is used for prefix caching. We use LRU cache for this function to avoid recomputing @@ -171,7 +172,7 @@ def hash_block_tokens(parent_block_hash: Optional[int], Args: parent_block_hash: The hash of the parent block. None if this is the first block. - curr_block_token_ids: A tuple of token ids in the current + curr_block_token_ids: A list of token ids in the current block. The current block is assumed to be full. Returns: @@ -179,11 +180,11 @@ def hash_block_tokens(parent_block_hash: Optional[int], The entire tuple is used as the hash key of the block. """ return BlockHashType(hash((parent_block_hash, *curr_block_token_ids)), - curr_block_token_ids) + tuple(curr_block_token_ids)) def hash_request_tokens(block_size: int, - token_ids: List[int]) -> List[BlockHashType]: + token_ids: Sequence[int]) -> List[BlockHashType]: """Computes hash values of a chain of blocks given a sequence of token IDs. The hash value is used for prefix caching. @@ -198,7 +199,7 @@ def hash_request_tokens(block_size: int, parent_block_hash_value = None for start in range(0, len(token_ids), block_size): end = start + block_size - block_token_ids = tuple(token_ids[start:end]) + block_token_ids = token_ids[start:end] # Do not hash the block if it is not full. if len(block_token_ids) < block_size: break diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index f055eed77c372..f76364f64033d 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -152,6 +152,7 @@ def schedule(self) -> "SchedulerOutput": break if not can_schedule: break + assert new_blocks is not None # Schedule the request. scheduled_running_reqs.append(request) diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index abeea052c1fa5..cc0c7ea23469a 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -36,7 +36,7 @@ class EngineCoreRequest: prompt: Optional[str] prompt_token_ids: List[int] mm_inputs: Optional[List[Optional[MultiModalKwargs]]] - mm_hashes: Optional[List[Optional[str]]] + mm_hashes: Optional[List[str]] mm_placeholders: Optional[MultiModalPlaceholderDict] sampling_params: SamplingParams eos_token_id: Optional[int] @@ -44,10 +44,11 @@ class EngineCoreRequest: lora_request: Optional[LoRARequest] -class EngineCoreOutput(msgspec.Struct, - array_like=True, - omit_defaults=True, - gc=False): +class EngineCoreOutput( + msgspec.Struct, + array_like=True, # type: ignore[call-arg] + omit_defaults=True, # type: ignore[call-arg] + gc=False): # type: ignore[call-arg] request_id: str new_token_ids: List[int] @@ -56,10 +57,11 @@ class EngineCoreOutput(msgspec.Struct, stop_reason: Union[int, str, None] = None -class EngineCoreOutputs(msgspec.Struct, - array_like=True, - omit_defaults=True, - gc=False): +class EngineCoreOutputs( + msgspec.Struct, + array_like=True, # type: ignore[call-arg] + omit_defaults=True, # type: ignore[call-arg] + gc=False): # type: ignore[call-arg] #NOTE(Nick): We could consider ways to make this more compact, # e.g. columnwise layout and using an int enum for finish/stop reason @@ -81,3 +83,6 @@ class EngineCoreRequestType(enum.Enum): ADD = b'\x00' ABORT = b'\x01' PROFILE = b'\x02' + + +EngineCoreRequestUnion = Union[EngineCoreRequest, EngineCoreProfile, List[str]] diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 24cafeff63d1e..b36de5f66917c 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -81,7 +81,7 @@ def __init__( asyncio_mode=True, ) - self.output_handler = None + self.output_handler: Optional[asyncio.Task] = None def __del__(self): self.shutdown() @@ -126,7 +126,8 @@ def shutdown(self): handler.cancel() @classmethod - def _get_executor_cls(cls, vllm_config: VllmConfig): + def _get_executor_cls(cls, vllm_config: VllmConfig) -> Type[Executor]: + executor_class: Type[Executor] distributed_executor_backend = ( vllm_config.parallel_config.distributed_executor_backend) if distributed_executor_backend == "mp": @@ -361,10 +362,10 @@ async def check_health(self) -> None: logger.debug("Called check_health.") async def start_profile(self) -> None: - await self.engine_core.profile(True) + await self.engine_core.profile_async(True) async def stop_profile(self) -> None: - await self.engine_core.profile(False) + await self.engine_core.profile_async(False) @property def is_running(self) -> bool: @@ -380,7 +381,7 @@ def errored(self) -> bool: @property def dead_error(self) -> BaseException: - return Exception + return Exception() # TODO: implement # Retain V0 name for backwards compatibility. diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index af644fb5fedba..56d4dc67e4a0e 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -5,7 +5,7 @@ import time from dataclasses import dataclass from multiprocessing.process import BaseProcess -from typing import List, Tuple, Type, Union +from typing import List, Tuple, Type import zmq import zmq.asyncio @@ -20,7 +20,7 @@ from vllm.v1.core.scheduler import Scheduler from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs, EngineCoreProfile, EngineCoreRequest, - EngineCoreRequestType) + EngineCoreRequestType, EngineCoreRequestUnion) from vllm.v1.engine.mm_input_mapper import MMInputMapperServer from vllm.v1.executor.abstract import Executor from vllm.v1.request import Request, RequestStatus @@ -97,8 +97,10 @@ def add_request(self, request: EngineCoreRequest): # Note that the cache here is mirrored with the client side of the # MM mapper, so anything that has a hash must have a HIT cache # entry here as well. - request.mm_inputs = self.mm_input_mapper_server.process_inputs( - request.mm_inputs, request.mm_hashes) + assert request.mm_inputs is not None + request.mm_inputs, request.mm_hashes = ( + self.mm_input_mapper_server.process_inputs( + request.mm_inputs, request.mm_hashes)) req = Request.from_engine_core_request(request) @@ -128,7 +130,7 @@ def step(self) -> List[EngineCoreOutput]: def shutdown(self): self.model_executor.shutdown() - def profile(self, is_start=True): + def profile(self, is_start: bool = True): self.model_executor.profile(is_start) @@ -161,8 +163,8 @@ def __init__( # and to overlap some serialization/deserialization with the # model forward pass. # Threads handle Socket <-> Queues and core_busy_loop uses Queue. - self.input_queue = queue.Queue() - self.output_queue = queue.Queue() + self.input_queue: queue.Queue[EngineCoreRequestUnion] = queue.Queue() + self.output_queue: queue.Queue[List[EngineCoreOutput]] = queue.Queue() threading.Thread(target=self.process_input_socket, args=(input_path, ), daemon=True).start() @@ -318,9 +320,7 @@ def _log_stats(self): self._last_logging_time = now - def _handle_client_request( - self, request: Union[EngineCoreRequest, EngineCoreProfile, - List[str]]) -> None: + def _handle_client_request(self, request: EngineCoreRequestUnion) -> None: """Handle EngineCoreRequest or EngineCoreABORT from Client.""" if isinstance(request, EngineCoreRequest): diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index e0bfe1b93b360..ff25a9b2e9cac 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -1,6 +1,6 @@ import atexit import os -from typing import List, Union +from typing import List, Optional import msgspec import zmq @@ -10,8 +10,9 @@ from vllm.utils import get_open_zmq_ipc_path, kill_process_tree from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs, EngineCoreProfile, EngineCoreRequest, - EngineCoreRequestType) -from vllm.v1.engine.core import EngineCore, EngineCoreProc + EngineCoreRequestType, EngineCoreRequestUnion) +from vllm.v1.engine.core import (EngineCore, EngineCoreProc, + EngineCoreProcHandle) from vllm.v1.serial_utils import PickleEncoder logger = init_logger(__name__) @@ -59,7 +60,7 @@ def get_output(self) -> List[EngineCoreOutput]: def add_request(self, request: EngineCoreRequest) -> None: raise NotImplementedError - async def profile(self, is_start=True) -> None: + def profile(self, is_start: bool = True) -> None: raise NotImplementedError def abort_requests(self, request_ids: List[str]) -> None: @@ -71,6 +72,9 @@ async def get_output_async(self) -> List[EngineCoreOutput]: async def add_request_async(self, request: EngineCoreRequest) -> None: raise NotImplementedError + async def profile_async(self, is_start: bool = True) -> None: + raise NotImplementedError + async def abort_requests_async(self, request_ids: List[str]) -> None: raise NotImplementedError @@ -105,7 +109,7 @@ def shutdown(self): def __del__(self): self.shutdown() - def profile(self, is_start=True) -> None: + def profile(self, is_start: bool = True) -> None: self.engine_core.profile(is_start) @@ -133,7 +137,10 @@ def __init__( self.decoder = msgspec.msgpack.Decoder(EngineCoreOutputs) # ZMQ setup. - self.ctx = (zmq.asyncio.Context() if asyncio_mode else zmq.Context()) + if asyncio_mode: + self.ctx = zmq.asyncio.Context() + else: + self.ctx = zmq.Context() # type: ignore[attr-defined] # Path for IPC. ready_path = get_open_zmq_ipc_path() @@ -149,11 +156,13 @@ def __init__( self.input_socket.bind(input_path) # Start EngineCore in background process. + self.proc_handle: Optional[EngineCoreProcHandle] self.proc_handle = EngineCoreProc.make_engine_core_process( *args, - input_path=input_path, - output_path=output_path, - ready_path=ready_path, + input_path= + input_path, # type: ignore[misc] # MyPy incorrectly flags duplicate keywords + output_path=output_path, # type: ignore[misc] + ready_path=ready_path, # type: ignore[misc] **kwargs, ) atexit.register(self.shutdown) @@ -204,10 +213,8 @@ def get_output(self) -> List[EngineCoreOutput]: engine_core_outputs = self.decoder.decode(frame.buffer).outputs return engine_core_outputs - def _send_input( - self, request_type: EngineCoreRequestType, - request: Union[EngineCoreRequest, EngineCoreProfile, - List[str]]) -> None: + def _send_input(self, request_type: EngineCoreRequestType, + request: EngineCoreRequestUnion) -> None: # (RequestType, SerializedRequest) msg = (request_type.value, self.encoder.encode(request)) @@ -219,7 +226,7 @@ def add_request(self, request: EngineCoreRequest) -> None: def abort_requests(self, request_ids: List[str]) -> None: self._send_input(EngineCoreRequestType.ABORT, request_ids) - def profile(self, is_start=True) -> None: + def profile(self, is_start: bool = True) -> None: self._send_input(EngineCoreRequestType.PROFILE, EngineCoreProfile(is_start)) @@ -237,10 +244,8 @@ async def get_output_async(self) -> List[EngineCoreOutput]: return engine_core_outputs - async def _send_input( - self, request_type: EngineCoreRequestType, - request: Union[EngineCoreRequest, EngineCoreProfile, - List[str]]) -> None: + async def _send_input(self, request_type: EngineCoreRequestType, + request: EngineCoreRequestUnion) -> None: msg = (request_type.value, self.encoder.encode(request)) await self.input_socket.send_multipart(msg, copy=False) @@ -252,6 +257,6 @@ async def abort_requests_async(self, request_ids: List[str]) -> None: if len(request_ids) > 0: await self._send_input(EngineCoreRequestType.ABORT, request_ids) - async def profile(self, is_start=True) -> None: + async def profile_async(self, is_start: bool = True) -> None: await self._send_input(EngineCoreRequestType.PROFILE, EngineCoreProfile(is_start)) diff --git a/vllm/v1/engine/detokenizer.py b/vllm/v1/engine/detokenizer.py index 6249d60199a62..02f34e2b54dd5 100644 --- a/vllm/v1/engine/detokenizer.py +++ b/vllm/v1/engine/detokenizer.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Dict, Iterable, List, Optional, Tuple +from typing import Dict, Iterable, List, Optional, Tuple, Union from vllm.engine.output_processor.stop_checker import StopChecker from vllm.logger import init_logger @@ -97,7 +97,7 @@ def add_tokens( self, new_token_ids: List[int], finish_reason: Optional[str], - stop_reason: Optional[str], + stop_reason: Optional[Union[int, str, None]], ) -> Optional[RequestOutput]: """ Update RequestState for the request_id by: diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index c02494897b41f..15dedbd0f9529 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -103,7 +103,8 @@ def from_engine_args( multiprocess_mode=enable_multiprocessing) @classmethod - def _get_executor_cls(cls, vllm_config: VllmConfig): + def _get_executor_cls(cls, vllm_config: VllmConfig) -> Type[Executor]: + executor_class: Type[Executor] distributed_executor_backend = ( vllm_config.parallel_config.distributed_executor_backend) if distributed_executor_backend == "mp": diff --git a/vllm/v1/engine/mm_input_mapper.py b/vllm/v1/engine/mm_input_mapper.py index 58ee29bedb201..cca27c2218af7 100644 --- a/vllm/v1/engine/mm_input_mapper.py +++ b/vllm/v1/engine/mm_input_mapper.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple import PIL from blake3 import blake3 @@ -42,14 +42,14 @@ def __init__( model_config) self.mm_registry.init_mm_limits_per_prompt(model_config) - self.mm_cache = LRUDictCache(MM_CACHE_SIZE) + self.mm_cache = LRUDictCache[str, MultiModalKwargs](MM_CACHE_SIZE) # DEBUG: Set to None to disable self.mm_debug_cache_hit_ratio_steps = None self.mm_cache_hits = 0 self.mm_cache_total = 0 - def cache_hit_ratio(self, steps) -> float: + def cache_hit_ratio(self, steps): if self.mm_cache_total > 0 and self.mm_cache_total % steps == 0: logger.debug("MMInputMapper: cache_hit_ratio = %.2f ", self.mm_cache_hits / self.mm_cache_total) @@ -60,7 +60,7 @@ def process_inputs( mm_hashes: Optional[List[str]], mm_processor_kwargs: Optional[Dict[str, Any]], precomputed_mm_inputs: Optional[List[MultiModalKwargs]], - ) -> List[MultiModalKwargs]: + ) -> Tuple[List[MultiModalKwargs], Optional[List[str]]]: if precomputed_mm_inputs is None: image_inputs = mm_data["image"] if not isinstance(image_inputs, list): @@ -72,6 +72,7 @@ def process_inputs( # Check if hash is enabled use_hash = mm_hashes is not None if use_hash: + assert mm_hashes is not None assert num_inputs == len( mm_hashes), "num_inputs = {} len(mm_hashes) = {}".format( num_inputs, len(mm_hashes)) @@ -79,7 +80,7 @@ def process_inputs( # Process each image input separately, so that later we can schedule # them in a fine-grained manner. # Apply caching (if enabled) and reuse precomputed inputs (if provided) - ret_hashes = [] if use_hash else None + ret_hashes: Optional[List[str]] = [] if use_hash else None ret_inputs: List[MultiModalKwargs] = [] for input_id in range(num_inputs): if self.mm_debug_cache_hit_ratio_steps is not None: @@ -88,6 +89,7 @@ def process_inputs( mm_hash = None mm_input = None if use_hash: + assert mm_hashes is not None mm_hash = mm_hashes[input_id] mm_input = self.mm_cache.get(mm_hash) @@ -105,12 +107,15 @@ def process_inputs( if use_hash: # Add to cache + assert mm_hash is not None self.mm_cache.put(mm_hash, mm_input) else: self.mm_cache_hits += 1 mm_input = None # Avoids sending mm_input to Server if use_hash: + assert mm_hash is not None + assert ret_hashes is not None ret_hashes.append(mm_hash) ret_inputs.append(mm_input) @@ -120,17 +125,18 @@ def process_inputs( class MMInputMapperServer: def __init__(self, ): - self.mm_cache = LRUDictCache(MM_CACHE_SIZE) + self.mm_cache = LRUDictCache[str, MultiModalKwargs](MM_CACHE_SIZE) def process_inputs( self, mm_inputs: List[Optional[MultiModalKwargs]], - mm_hashes: List[Optional[str]], + mm_hashes: List[str], ) -> List[MultiModalKwargs]: assert len(mm_inputs) == len(mm_hashes) full_mm_inputs = [] for mm_input, mm_hash in zip(mm_inputs, mm_hashes): + assert mm_hash is not None if mm_input is None: mm_input = self.mm_cache.get(mm_hash) assert mm_input is not None diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 903996bad3726..679bf8e25e9ca 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -56,7 +56,7 @@ def process_inputs( request_id: str, prompt: PromptType, params: Union[SamplingParams, PoolingParams], - arrival_time: float, + arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, diff --git a/vllm/v1/executor/abstract.py b/vllm/v1/executor/abstract.py index 9cd267581ad18..564d0447f15a6 100644 --- a/vllm/v1/executor/abstract.py +++ b/vllm/v1/executor/abstract.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Dict, Optional, Tuple +from typing import Tuple from vllm.config import VllmConfig from vllm.v1.outputs import ModelRunnerOutput @@ -28,7 +28,7 @@ def execute_model( raise NotImplementedError @abstractmethod - def profile(self, is_start=True): + def profile(self, is_start: bool = True): raise NotImplementedError @abstractmethod @@ -38,11 +38,3 @@ def shutdown(self): @abstractmethod def check_health(self) -> None: raise NotImplementedError - - @abstractmethod - def collective_rpc(self, - method: str, - timeout: Optional[float] = None, - args: Tuple = (), - kwargs: Optional[Dict] = None) -> []: - raise NotImplementedError diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index 14384a730ceec..17441dacdc5cf 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -7,7 +7,7 @@ from dataclasses import dataclass from enum import Enum, auto from multiprocessing.process import BaseProcess -from typing import Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple import zmq @@ -21,6 +21,7 @@ from vllm.logger import init_logger from vllm.utils import (get_distributed_init_method, get_open_port, get_open_zmq_ipc_path) +from vllm.v1.executor.abstract import Executor from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.utils import make_zmq_socket from vllm.worker.worker_base import WorkerWrapperBase @@ -31,7 +32,7 @@ POLLING_TIMEOUT_S = POLLING_TIMEOUT_MS // 1000 -class MultiprocExecutor: +class MultiprocExecutor(Executor): def __init__(self, vllm_config: VllmConfig) -> None: # Call self.shutdown at exit to clean up @@ -103,7 +104,7 @@ def collective_rpc(self, method: str, timeout: Optional[float] = None, args: Tuple = (), - kwargs: Optional[Dict] = None) -> []: + kwargs: Optional[Dict] = None) -> List[Any]: """ Execute an RPC call on workers. @@ -125,7 +126,7 @@ def collective_rpc(self, responses = [None] * self.world_size for w in self.workers: - dequeue_timeout = timeout - (time.monotonic() - start_time() + dequeue_timeout = timeout - (time.monotonic() - start_time ) if timeout is not None else None status, result = w.worker_response_mq.dequeue( timeout=dequeue_timeout) @@ -153,7 +154,7 @@ def execute_model( args=(scheduler_output, ))[0] return model_output - def profile(self, is_start=True): + def profile(self, is_start: bool = True): self.collective_rpc("profile", args=(is_start, )) return @@ -185,7 +186,6 @@ def wait_for_termination(procs, timeout): p.kill() self._cleanup_sockets() - self.workers = None def _cleanup_sockets(self): for w in self.workers: @@ -200,7 +200,8 @@ def shutdown(self): # again atexit.unregister(self.shutdown) """Properly shut down the executor and its workers""" - if (hasattr(self, 'workers') and self.workers is not None): + if getattr(self, 'shutting_down', False): + self.shutting_down = True for w in self.workers: #TODO: not sure if needed w.worker_response_mq = None self._ensure_worker_termination() diff --git a/vllm/v1/executor/uniproc_executor.py b/vllm/v1/executor/uniproc_executor.py index 9b1d9a40950c6..be058318de58b 100644 --- a/vllm/v1/executor/uniproc_executor.py +++ b/vllm/v1/executor/uniproc_executor.py @@ -4,13 +4,14 @@ from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.utils import get_distributed_init_method, get_ip, get_open_port +from vllm.v1.executor.abstract import Executor from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.worker.gpu_worker import Worker logger = init_logger(__name__) -class UniprocExecutor: +class UniprocExecutor(Executor): def __init__(self, vllm_config: VllmConfig) -> None: self.vllm_config = vllm_config @@ -25,7 +26,7 @@ def __init__(self, vllm_config: VllmConfig) -> None: self.prompt_adapter_config = vllm_config.prompt_adapter_config self.observability_config = vllm_config.observability_config - self.worker = self._create_worker() + self.worker: Worker = self._create_worker() self.worker.initialize() self.worker.load_model() @@ -75,7 +76,7 @@ def profile(self, is_start: bool = True): self.worker.profile(is_start) def shutdown(self): - self.worker = None + pass def check_health(self) -> None: # UniprocExecutor will always be healthy as long as diff --git a/vllm/v1/request.py b/vllm/v1/request.py index 6bc1e4d5c769f..1737d096e811d 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -52,10 +52,9 @@ def __init__( else: self.mm_positions = [] # Output of the mm input mapper (e.g., image tensors). + self.mm_inputs: List[MultiModalKwargs] = [] if self.inputs.multi_modal_inputs: self.mm_inputs = self.inputs.multi_modal_inputs - else: - self.mm_inputs: List[MultiModalKwargs] = [] @classmethod def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request": diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index 6ecf20e717ca3..5f327d7066830 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -1,6 +1,8 @@ from collections import OrderedDict +from collections.abc import Sequence from contextlib import contextmanager -from typing import Any, Generic, Iterator, List, TypeVar, overload +from typing import (Any, Generic, Iterator, List, Optional, TypeVar, Union, + overload) import zmq @@ -11,7 +13,7 @@ T = TypeVar("T") -class ConstantList(Generic[T]): +class ConstantList(Generic[T], Sequence): def __init__(self, x: List[T]) -> None: self._x = x @@ -34,29 +36,33 @@ def remove(self, item): def clear(self): raise Exception("Cannot clear a constant list") - def index(self, item): - return self._x.index(item) + def index(self, + item: T, + start: int = 0, + stop: Optional[int] = None) -> int: + return self._x.index(item, start, + stop if stop is not None else len(self._x)) @overload - def __getitem__(self, item) -> T: + def __getitem__(self, item: int) -> T: ... @overload def __getitem__(self, s: slice, /) -> List[T]: ... - def __getitem__(self, item): + def __getitem__(self, item: Union[int, slice]) -> Union[T, List[T]]: return self._x[item] @overload - def __setitem__(self, item, value): + def __setitem__(self, item: int, value: T): ... @overload - def __setitem__(self, s: slice, value, /): + def __setitem__(self, s: slice, value: T, /): ... - def __setitem__(self, item, value): + def __setitem__(self, item: Union[int, slice], value: Union[T, List[T]]): raise Exception("Cannot set item in a constant list") def __delitem__(self, item): @@ -73,10 +79,12 @@ def __len__(self): @contextmanager -def make_zmq_socket(path: str, type: Any) -> Iterator[zmq.Socket]: +def make_zmq_socket( + path: str, + type: Any) -> Iterator[zmq.Socket]: # type: ignore[name-defined] """Context manager for a ZMQ socket""" - ctx = zmq.Context() + ctx = zmq.Context() # type: ignore[attr-defined] try: socket = ctx.socket(type) @@ -96,20 +104,24 @@ def make_zmq_socket(path: str, type: Any) -> Iterator[zmq.Socket]: ctx.destroy(linger=0) -class LRUDictCache: +K = TypeVar('K') +V = TypeVar('V') + + +class LRUDictCache(Generic[K, V]): def __init__(self, size: int): - self.cache = OrderedDict() + self.cache: OrderedDict[K, V] = OrderedDict() self.size = size - def get(self, key, default=None): + def get(self, key: K, default=None) -> V: if key not in self.cache: return default self.cache.move_to_end(key) return self.cache[key] - def put(self, key, value): + def put(self, key: K, value: V): self.cache[key] = value self.cache.move_to_end(key) if len(self.cache) > self.size: diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 9046b37f60005..5c113c74778df 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -215,6 +215,7 @@ def condense(self, empty_req_indices: List[int]) -> None: # Swap the states. req_id = self.req_ids[last_req_index] + assert req_id is not None self.req_ids[empty_index] = req_id self.req_ids[last_req_index] = None self.req_id_to_index[req_id] = empty_index diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index f24942068d1f8..abcd4b007a326 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1,6 +1,6 @@ import gc import time -from typing import TYPE_CHECKING, Dict, List, Tuple +from typing import TYPE_CHECKING, Dict, List, Tuple, cast import numpy as np import torch @@ -193,9 +193,9 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: req_ids_to_add: List[str] = [] # Add new requests to the cached states. - for req_data in scheduler_output.scheduled_new_reqs: - req_id = req_data.req_id - sampling_params = req_data.sampling_params + for new_req_data in scheduler_output.scheduled_new_reqs: + req_id = new_req_data.req_id + sampling_params = new_req_data.sampling_params if sampling_params.sampling_type == SamplingType.RANDOM_SEED: generator = torch.Generator(device=self.device) generator.manual_seed(sampling_params.seed) @@ -204,25 +204,25 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: self.requests[req_id] = CachedRequestState( req_id=req_id, - prompt_token_ids=req_data.prompt_token_ids, - prompt=req_data.prompt, - mm_inputs=req_data.mm_inputs, - mm_positions=req_data.mm_positions, + prompt_token_ids=new_req_data.prompt_token_ids, + prompt=new_req_data.prompt, + mm_inputs=new_req_data.mm_inputs, + mm_positions=new_req_data.mm_positions, sampling_params=sampling_params, generator=generator, - block_ids=req_data.block_ids, - num_computed_tokens=req_data.num_computed_tokens, + block_ids=new_req_data.block_ids, + num_computed_tokens=new_req_data.num_computed_tokens, output_token_ids=[], ) req_ids_to_add.append(req_id) # Update the cached states of the resumed requests. - for req_data in scheduler_output.scheduled_resumed_reqs: - req_id = req_data.req_id + for res_req_data in scheduler_output.scheduled_resumed_reqs: + req_id = res_req_data.req_id req_state = self.requests[req_id] - req_state.block_ids = req_data.block_ids - req_state.num_computed_tokens = req_data.num_computed_tokens + req_state.block_ids = res_req_data.block_ids + req_state.num_computed_tokens = res_req_data.num_computed_tokens req_ids_to_add.append(req_id) # Add the new or resumed requests to the persistent batch. @@ -259,6 +259,7 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): num_scheduled_tokens = [] max_num_scheduled_tokens = 0 for req_id in self.input_batch.req_ids[:num_reqs]: + assert req_id is not None num_tokens = scheduler_output.num_scheduled_tokens[req_id] num_scheduled_tokens.append(num_tokens) max_num_scheduled_tokens = max(max_num_scheduled_tokens, @@ -373,7 +374,7 @@ def _execute_encoder(self, scheduler_output: "SchedulerOutput"): # Batch the multi-modal inputs. mm_inputs: List[MultiModalKwargs] = [] - req_input_ids: List[Tuple[int, int]] = [] + req_input_ids: List[Tuple[str, int]] = [] for req_id, encoder_input_ids in scheduled_encoder_inputs.items(): req_state = self.requests[req_id] for input_id in encoder_input_ids: @@ -406,6 +407,7 @@ def _gather_encoder_outputs( encoder_outputs: List[torch.Tensor] = [] num_reqs = self.input_batch.num_reqs for req_id in self.input_batch.req_ids[:num_reqs]: + assert req_id is not None num_scheduled_tokens = scheduler_output.num_scheduled_tokens[ req_id] req_state = self.requests[req_id] @@ -514,6 +516,7 @@ def execute_model( # the requests one by one. Optimize. num_reqs = self.input_batch.num_reqs for i, req_id in enumerate(self.input_batch.req_ids[:num_reqs]): + assert req_id is not None req_state = self.requests[req_id] seq_len = (req_state.num_computed_tokens + scheduler_output.num_scheduled_tokens[req_id]) @@ -539,8 +542,15 @@ def execute_model( logprobs = None else: logprobs = sampler_output.logprobs.cpu() + + # num_reqs entries should be non-None + assert all( + req_id is not None for req_id in + self.input_batch.req_ids[:num_reqs]), "req_ids contains None" + req_ids = cast(List[str], self.input_batch.req_ids[:num_reqs]) + model_runner_output = ModelRunnerOutput( - req_ids=self.input_batch.req_ids[:num_reqs], + req_ids=req_ids, req_id_to_index=self.input_batch.req_id_to_index, sampled_token_ids=sampled_token_ids, logprob_token_ids_cpu=logprob_token_ids, diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 49e415ab72e0b..33491f700de10 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -204,7 +204,7 @@ def execute_model( return output if self.rank == 0 else None return output - def profile(self, is_start=True): + def profile(self, is_start: bool = True): if self.profiler is None: raise RuntimeError("Profiler is not enabled.") if is_start: