From fe1e1b4180afeb53b8bc5515da56231b28fd3a5e Mon Sep 17 00:00:00 2001
From: "rshaw@neuralmagic.com"
Date: Sun, 27 Oct 2024 22:34:17 +0000
Subject: [PATCH] formatting tweaks

---
 vllm/v1/engine/llm_engine.py      | 49 +++++++++++++++++++------------
 vllm/v1/engine/llm_engine_core.py | 47 ++++++++++++++---------------
 2 files changed, 54 insertions(+), 42 deletions(-)

diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py
index 991cc52151e59..2fdcde1130dc3 100644
--- a/vllm/v1/engine/llm_engine.py
+++ b/vllm/v1/engine/llm_engine.py
@@ -1,5 +1,6 @@
 import time
-from typing import Any, Dict, List, Mapping, Optional, Tuple, Type, Union
+from typing import (Any, Dict, Iterable, List, Mapping, Optional, Tuple, Type,
+                    Union)
 
 import msgspec
 import zmq
@@ -152,10 +153,10 @@ def __init__(
 
         self.ctx = zmq.Context()  # type: ignore[attr-defined]
 
-        self.from_core_ipc_path = get_open_zmq_ipc_path()
         self.to_core_ipc_path = get_open_zmq_ipc_path()
 
         # Get output (EngineCoreOutput) from LLMEngineCore.
+        self.from_core_ipc_path = get_open_zmq_ipc_path()
         self.from_core = self.ctx.socket(zmq.constants.PULL)
         self.from_core.bind(self.from_core_ipc_path)
 
@@ -163,6 +164,8 @@ def __init__(
         self.to_core = self.ctx.socket(zmq.constants.PUSH)
         self.to_core.bind(self.to_core_ipc_path)
 
+        # TODO: start up the engine core.
+
     @classmethod
     def from_engine_args(
         cls,
@@ -247,6 +250,28 @@ def add_request(
                                      copy=False,
                                      flags=zmq.NOBLOCK)
 
+    def step(self) -> List[RequestOutput]:
+        # NOTE: This method returns an empty list if no EngineCoreOutputs
+        # are ready yet; an empty list does not mean generation is finished.
+        if self.from_core.poll(timeout=0) != 0:
+            frames = self.from_core.recv_multipart(copy=False)
+            engine_core_outputs = self.decoder(frames[0].buffer).outputs
+            request_outputs = self.detokenizer.step(engine_core_outputs)
+            return request_outputs
+
+        return []
+
+    def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
+        # TODO: send to EngineCore
+        # TODO: send to Detokenizer
+        pass
+
+    def check_health(self) -> None:
+        if self.tokenizer:
+            self.tokenizer.check_health()
+        # self.model_executor.check_health()
+        # TODO: send health check to EngineCore.
+
     def _make_requests(
         self,
         request_id: str,
@@ -257,6 +282,7 @@ def _make_requests(
         prompt_adapter_request: Optional[PromptAdapterRequest] = None
     ) -> Tuple[DetokenizerRequest, EngineCoreRequest]:
 
+        # Process inputs.
         preprocessed_inputs = self.input_preprocessor.preprocess(
             prompt,
             request_id=request_id,
@@ -272,7 +298,7 @@ def _make_requests(
         sampling_params.update_from_generation_config(
             self.generation_config_fields, eos_token_id)
 
-        # Make input to Detokenizer
+        # Make Request for Detokenizer.
         detokenizer_request = DetokenizerRequest(
             request_id, processed_inputs.prompt,
             processed_inputs.prompt_token_ids,
@@ -280,28 +306,13 @@ def _make_requests(
             sampling_params.spaces_between_special_tokens,
             sampling_params.output_kind)
 
-        # Make input to EngineCore
+        # Make Request for EngineCore.
         engine_core_request = EngineCoreRequest(request_id, processed_inputs,
                                                 sampling_params, eos_token_id,
                                                 arrival_time, lora_request)
 
         return detokenizer_request, engine_core_request
 
-    def step(self) -> List[RequestOutput]:
-        if self.from_core.poll(timeout=0) != 0:
-            frames = self.from_core.recv_multipart(copy=False)
-            engine_core_outputs = self.decoder(frames[0].buffer).outputs
-            request_outputs = self.detokenizer.step(engine_core_outputs)
-            return request_outputs
-
-        return []
-
-    def check_health(self) -> None:
-        if self.tokenizer:
-            self.tokenizer.check_health()
-        # self.model_executor.check_health()
-        # TODO: send health check to EngineCore.
-
     def _validate_model_inputs(self, inputs: Union[DecoderOnlyInputs,
                                                    EncoderDecoderLLMInputs]):
         prompt_ids = inputs.get("prompt_token_ids")
diff --git a/vllm/v1/engine/llm_engine_core.py b/vllm/v1/engine/llm_engine_core.py
index 497d371eecae6..38718d3798068 100644
--- a/vllm/v1/engine/llm_engine_core.py
+++ b/vllm/v1/engine/llm_engine_core.py
@@ -38,7 +38,6 @@ def __init__(
         observability_config: Optional[ObservabilityConfig],
         prompt_adapter_config: Optional[PromptAdapterConfig],
     ):
-
         self.input_path = input_path
         self.output_path = output_path
         self.executor_class = executor_class
@@ -58,14 +57,15 @@ def run(self):
         self.msgpack_encoder = msgspec.msgpack.Encoder()
         self.ctx = zmq.Context()  # type: ignore[attr-defined]
 
-        # Get input (new Requests) from the LLMEngine.
+        # Get EngineCoreRequests from the LLMEngine.
         self.input_socket = self.ctx.socket(zmq.constants.PULL)
         self.input_socket.connect(self.input_path)
 
-        # Send output (EngineCoreOutput) to the LLMEngine.
+        # Send EngineCoreOutput to the LLMEngine.
         self.output_socket = self.ctx.socket(zmq.constants.PUSH)
         self.output_socket.connect(self.output_path)
 
+        # Setup Model.
         self.model_executor = self.executor_class(
             model_config=self.model_config,
             cache_config=self.cache_config,
@@ -79,18 +79,20 @@ def run(self):
             observability_config=self.observability_config,
         )
 
+        # Setup KV Caches.
+        # NOTE: the cache_config is updated with the number of GPU and CPU
+        # blocks, which are profiled in the distributed executor.
         self._initialize_kv_caches()
 
-        # NOTE: the cache_config here have been updated with the numbers of
-        # GPU and CPU blocks, which are profiled in the distributed executor.
+        # Setup Scheduler.
         self.scheduler = Scheduler(self.scheduler_config, self.cache_config,
                                    self.lora_config)
 
-        # TODO: add heartbeat thread.
-
-        # Run core loop.
+        # Kick off the busy loop.
         self._run_busy_loop()
 
+        # TODO: add heartbeat thread.
+
     def _initialize_kv_caches(self) -> None:
         num_gpu_blocks, _ = self.model_executor.determine_num_available_blocks(
         )
@@ -124,6 +126,18 @@ def _run_busy_loop(self):
             # Send outputs back to the LLMEngine.
             self._send_outputs(outputs)
 
+    def _step(self) -> Optional[List[EngineCoreOutput]]:
+        """Schedule, execute, and make output."""
+
+        if not self.scheduler.has_unfinished_requests():
+            return None
+
+        scheduler_output = self.scheduler.schedule()
+        output = self.model_executor.execute_model(scheduler_output)
+        engine_core_outputs = self.scheduler.update_from_output(
+            scheduler_output, output)
+        return engine_core_outputs
+
     def _handle_new_input(self):
         """Handle new input from the LLMEngine."""
         try:
@@ -142,18 +156,6 @@ def _handle_new_input(self):
             # TODO: handle gracefully
             raise e
 
-    def _step(self) -> Optional[List[EngineCoreOutputs]]:
-        """Schedule, execute, and make output."""
-
-        if not self.scheduler.has_unfinished_requests():
-            return None
-
-        scheduler_output = self.scheduler.schedule()
-        output = self.model_executor.execute_model(scheduler_output)
-        engine_core_outputs = self.scheduler.update_from_output(
-            scheduler_output, output)
-        return engine_core_outputs
-
     def _send_outputs(
             self,
             engine_core_outputs: Optional[List[EngineCoreOutput]]) -> None:
@@ -162,9 +164,8 @@ def _send_outputs(
         if engine_core_outputs is None:
             return
 
-        outputs_serialized = self.msgpack_encoder.encode(
-            EngineCoreOutputs(data=engine_core_outputs))
-
+        outputs = EngineCoreOutputs(data=engine_core_outputs)
+        outputs_serialized = self.msgpack_encoder.encode(outputs)
         self.output_socket.send_multipart((outputs_serialized, ),
                                           copy=False,
                                           flags=zmq.NOBLOCK)
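---

For reference, a minimal, self-contained sketch of the msgspec-over-ZMQ
pattern this patch wires up between LLMEngine (PUSH, bind) and LLMEngineCore
(PULL, connect). ToyRequest, ipc_path, and main() are hypothetical stand-ins
for illustration, not vLLM names; only the zmq and msgspec calls mirror the
patch.

import msgspec
import zmq
from typing import List


class ToyRequest(msgspec.Struct):
    # Stand-in for EngineCoreRequest; the fields are illustrative.
    request_id: str
    prompt_token_ids: List[int]


def main():
    ctx = zmq.Context()
    # Stand-in for get_open_zmq_ipc_path(); ipc:// requires a Unix platform.
    ipc_path = "ipc:///tmp/toy_engine_core"

    # LLMEngine side: bind a PUSH socket and send encoded requests
    # (the patch's self.to_core).
    to_core = ctx.socket(zmq.PUSH)
    to_core.bind(ipc_path)

    # LLMEngineCore side: connect a PULL socket and receive them
    # (the patch's self.input_socket).
    from_engine = ctx.socket(zmq.PULL)
    from_engine.connect(ipc_path)

    encoder = msgspec.msgpack.Encoder()
    decoder = msgspec.msgpack.Decoder(ToyRequest)

    request = ToyRequest(request_id="req-0", prompt_token_ids=[1, 2, 3])
    to_core.send_multipart((encoder.encode(request), ), copy=False)

    # Poll before receiving, as in LLMEngine.step(): poll() == 0 only means
    # nothing has arrived yet, not that the stream is finished.
    if from_engine.poll(timeout=1000) != 0:
        frames = from_engine.recv_multipart(copy=False)
        print(decoder.decode(frames[0].buffer))

    ctx.destroy()


if __name__ == "__main__":
    main()

PUSH/PULL plus a zero-timeout poll is what lets LLMEngine.step() return an
empty list instead of blocking when the core has nothing ready yet.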