diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 41659ff62747d..9b048ee27dca4 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -1,12 +1,13 @@ from collections import deque from dataclasses import dataclass -from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union +from typing import Deque, Dict, Iterable, List, Optional, Set, Union from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig from vllm.logger import init_logger from vllm.multimodal import MultiModalDataDict from vllm.sampling_params import SamplingParams from vllm.v1.core.kv_cache_manager import KVCacheManager +from vllm.v1.engine import EngineCoreOutput from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.request import Request, RequestStatus @@ -227,13 +228,13 @@ def update_from_output( self, scheduler_output: "SchedulerOutput", model_runner_output: "ModelRunnerOutput", - ) -> List[Tuple[Request, int]]: + ) -> List[EngineCoreOutput]: + # NOTE(robertgshaw2): Should this method be in EngineCore instead? # NOTE(woosuk): This method doesn't consider speculative decoding. sampled_token_ids = model_runner_output.sampled_token_ids_cpu.tolist() num_scheduled_tokens = scheduler_output.num_scheduled_tokens new_running: List[Request] = [] - # (request, num_sampled_tokens) - sampled: List[Tuple[Request, int]] = [] + engine_core_outputs: List[EngineCoreOutput] = [] for request in self.running: req_id = request.request_id request.num_computed_tokens += num_scheduled_tokens[req_id] @@ -247,17 +248,30 @@ def update_from_output( # generates at most one token at each step. token_id = sampled_token_ids[req_index] request.output_token_ids.append(token_id) - sampled.append((request, 1)) + num_new_tokens = 1 + # TODO: Update the KV cache manager for prefix caching. - # Check if the request is finished. + # Check for stop and update request state. + # This must be called before we make the EngineCoreOutput. stopped = self._check_stop(request) + + # Add EngineCoreOutput for this Request. + output = EngineCoreOutput( + request_id=req_id, + new_token_ids=request.output_token_ids[-num_new_tokens:], + finished=request.is_finished(), + finish_reason=request.get_finished_reason(), + stop_reason=request.stop_reason) + engine_core_outputs.append(output) + + # Break out of the loop.
if stopped: continue new_running.append(request) self.running = new_running - return sampled + return engine_core_outputs def _check_stop(self, request: Request) -> bool: if (request.num_tokens >= self.max_model_len diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index e69de29bb2d1d..1c3f9936a2bb5 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -0,0 +1,51 @@ +from dataclasses import dataclass +from typing import List, Optional, Union + +import msgspec + +from vllm.lora.request import LoRARequest +from vllm.sampling_params import RequestOutputKind, SamplingParams + +LLM_ENGINE_CORE_READY_STR = "READY" + + +@dataclass +class DetokenizerRequest: + + request_id: str + prompt: Optional[str] + prompt_token_ids: List[int] + skip_special_tokens: bool + spaces_between_special_tokens: bool + output_kind: RequestOutputKind + + +class EngineCoreRequest(msgspec.Struct): + + # NOTE: prompt and prompt_token_ids should be DecoderOnlyInput, + # but this is not playing well with msgspec due to circular + # imports and weird typing we have going on in data.py + + request_id: str + prompt: Optional[str] + prompt_token_ids: List[int] + sampling_params: SamplingParams + eos_token_id: Optional[int] + arrival_time: float + lora_request: Optional[LoRARequest] + + +@dataclass +class EngineCoreOutput: + + request_id: str + new_token_ids: List[int] + finished: bool + finish_reason: Optional[str] = None + stop_reason: Union[int, str, None] = None + + +class EngineCoreOutputs(msgspec.Struct): + + # [num_reqs] + outputs: List[EngineCoreOutput] diff --git a/vllm/v1/engine/detokenizer.py b/vllm/v1/engine/detokenizer.py new file mode 100644 index 0000000000000..cbfb1073d6247 --- /dev/null +++ b/vllm/v1/engine/detokenizer.py @@ -0,0 +1,216 @@ +from dataclasses import dataclass +from typing import Dict, List, Optional + +from vllm.logger import init_logger +from vllm.outputs import CompletionOutput, RequestOutput +from vllm.sampling_params import RequestOutputKind +from vllm.transformers_utils.detokenizer_utils import ( + AnyTokenizer, convert_prompt_ids_to_tokens, detokenize_incrementally) +from vllm.transformers_utils.tokenizer import get_tokenizer +from vllm.v1.engine import DetokenizerRequest, EngineCoreOutput + +logger = init_logger(__name__) + + +@dataclass +class DetokenizerRequestState: + + # Generation data + output_text: str + tokens: List[str] + token_ids: List[int] + + # Metadata for incremental detokenization + prefix_offset: int + read_offset: int + + # Parameters for detokenization + skip_special_tokens: bool + spaces_between_special_tokens: bool + output_kind: RequestOutputKind + + # Request output (Cached + updated incrementally) + request_output: RequestOutput + + @classmethod + def from_new_request(cls, tokenizer: AnyTokenizer, + request: DetokenizerRequest): + + tokens, prefix_offset, read_offset = convert_prompt_ids_to_tokens( + tokenizer=tokenizer, + prompt_ids=request.prompt_token_ids, + skip_special_tokens=request.skip_special_tokens, + ) + + request_output = cls._initialize_request_output( + request.request_id, + request.prompt, + request.prompt_token_ids, + ) + + return cls( + output_text="", + tokens=tokens, + # Detokenizer mutates this list, so need a unique copy. + token_ids=request.prompt_token_ids.copy(), + prefix_offset=prefix_offset, + read_offset=read_offset, + skip_special_tokens=request.skip_special_tokens, + spaces_between_special_tokens=request. 
+ spaces_between_special_tokens, + output_kind=request.output_kind, + request_output=request_output) + + @staticmethod + def _initialize_request_output( + request_id: str, prompt: str, + prompt_token_ids: List[int]) -> RequestOutput: + """Initialize a new RequestOutput object.""" + + # TODO: Support `n` > 1. + completion_output = CompletionOutput( + index=0, + text="", + token_ids=[], + cumulative_logprob=None, + logprobs=None, # TODO + finish_reason=None, + stop_reason=None, + lora_request=None, + ) + + return RequestOutput( + request_id=request_id, + prompt=prompt, + prompt_token_ids=prompt_token_ids, + prompt_logprobs=None, # TODO + outputs=[completion_output], + finished=False, + metrics=None, + lora_request=None, + encoder_prompt=None, + encoder_prompt_token_ids=None, + ) + + +class Detokenizer: + + def __init__(self, tokenizer_name: str): + self.tokenizer = get_tokenizer(tokenizer_name) + + # Request id -> DetokenizerRequestState + self.request_states: Dict[str, DetokenizerRequestState] = {} + + def get_num_unfinished_requests(self): + return len(self.request_states) + + def has_unfinished_requests(self) -> bool: + return len(self.request_states) > 0 + + def add_request(self, request: DetokenizerRequest) -> None: + """Add new request to the Detokenizer.""" + + assert (request.request_id not in self.request_states) + + request_state = DetokenizerRequestState.from_new_request( + self.tokenizer, request) + self.request_states[request.request_id] = request_state + + def step( + self, encore_core_outputs: List[EngineCoreOutput] + ) -> List[RequestOutput]: + """Update the detokenizer state with the new tokens from EngineCore.""" + + request_outputs: List[RequestOutput] = [] + for engine_core_output in encore_core_outputs: + request_id = engine_core_output.request_id + request_state = self.request_states[request_id] + + # Detokenize and update state. + self._update_request_state( + tokenizer=self.tokenizer, + request_state=request_state, + new_token_ids=engine_core_output.new_token_ids, + finished=engine_core_output.finished, + finish_reason=engine_core_output.finish_reason, + stop_reason=engine_core_output.stop_reason, + ) + request_outputs.append(request_state.request_output) + + # Free completed requests. + if engine_core_output.finished: + self._free(request_id) + + # Send RequestOutputs to EngineClient. + return request_outputs + + def _free(self, request_id: str) -> None: + """Remove the request from the RequestState tracker.""" + + # TODO(robertgshaw2): should this be a del? + assert request_id in self.request_states + self.request_states.pop(request_id) + + @staticmethod + def _update_request_state( + tokenizer: AnyTokenizer, + request_state: DetokenizerRequestState, + new_token_ids: List[int], + finished: bool, + finish_reason: Optional[str], + stop_reason: Optional[str], + ) -> None: + """ + Update RequestState for the request_id by: + 1) Detokenize the new token ids incrementally. + 2) Update the RequestOutput with the new text. + """ + + # 1) Detokenize the new token ids incrementally. + # TODO(woosuk): This method becomes very inefficient when the number of + # new_token_ids is more than 1. We need to optimize this. 
+ decoded_text = "" + for new_token_id in new_token_ids: + request_state.token_ids.append(new_token_id) + (new_tokens, new_decoded_token_text, prefix_offset, + read_offset) = detokenize_incrementally( + tokenizer=tokenizer, + all_input_ids=request_state.token_ids, + prev_tokens=request_state.tokens, + prefix_offset=request_state.prefix_offset, + read_offset=request_state.read_offset, + skip_special_tokens=request_state.skip_special_tokens, + spaces_between_special_tokens=request_state. + spaces_between_special_tokens, + ) + + request_state.tokens.extend(new_tokens) + request_state.prefix_offset = prefix_offset + request_state.read_offset = read_offset + request_state.output_text += new_decoded_token_text + + decoded_text += new_decoded_token_text + + # 2) Update the RequestOutput object with the new text. + request_output = request_state.request_output + completion_output = request_output.outputs[0] + if request_state.output_kind == RequestOutputKind.CUMULATIVE: + completion_output.text += decoded_text + completion_output.token_ids = request_state.token_ids + elif request_state.output_kind == RequestOutputKind.DELTA: + completion_output.text = decoded_text + num_prev_tokens = len(completion_output.token_ids) + completion_output.token_ids = request_state.token_ids[ + num_prev_tokens:] + elif request_state.output_kind == RequestOutputKind.FINAL_ONLY: + if finished: + completion_output.text = request_state.output_text + completion_output.token_ids = request_state.token_ids + else: + completion_output.text = "" + completion_output.token_ids = [] + + if finished: + completion_output.finish_reason = finish_reason + completion_output.stop_reason = stop_reason + request_output.finished = finished diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 072e52bcd686a..2fbf9ced3ae25 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -2,6 +2,9 @@ from typing import (Any, Dict, Iterable, List, Mapping, Optional, Tuple, Type, Union) +import msgspec +import zmq + from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, EngineConfig, LoadConfig, LoRAConfig, ModelConfig, ObservabilityConfig, ParallelConfig, @@ -14,18 +17,20 @@ from vllm.inputs.preprocess import InputPreprocessor from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.outputs import CompletionOutput, RequestOutput +from vllm.outputs import RequestOutput from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sampling_params import RequestOutputKind, SamplingParams +from vllm.sampling_params import SamplingParams from vllm.transformers_utils.config import try_get_generation_config from vllm.transformers_utils.tokenizer_group import ( BaseTokenizerGroup, init_tokenizer_from_configs) from vllm.usage.usage_lib import UsageContext -from vllm.v1.core.scheduler import Scheduler +from vllm.utils import get_open_zmq_ipc_path +from vllm.v1.engine import (LLM_ENGINE_CORE_READY_STR, DetokenizerRequest, + EngineCoreOutputs, EngineCoreRequest) +from vllm.v1.engine.detokenizer import Detokenizer +from vllm.v1.engine.llm_engine_core import LLMEngineCore from vllm.v1.executor.gpu_executor import GPUExecutor -from vllm.v1.request import Request, RequestStatus -from vllm.v1.tokenizer.detokenizer import Detokenizer, DetokenizerInputs from vllm.version import __version__ as VLLM_VERSION logger = init_logger(__name__) @@ -113,6 +118,7 @@ def __init__( ) self.model_config = model_config + assert 
self.model_config.task != "embedding" self.cache_config = cache_config self.lora_config = lora_config self.parallel_config = parallel_config @@ -142,53 +148,58 @@ def __init__( self.input_processor = input_registry.create_input_processor( model_config) - # Request id -> Request - self.requests: Dict[str, Request] = {} - # NOTE(woosuk): Now that the detokenizer works asynchronously, we need - # to keep track of how many steps each request has been lagged behind - # in terms of detokenization. - # Request id -> how many detokenizer steps the request should wait for. - self.num_lagged_steps: Dict[str, int] = {} - # OPTIMIZATION: Cache the request output and update it incrementally. - # This is used to avoid creating a new RequestOutput object every step. - # Request id -> RequestOutput - self.request_outputs: Dict[str, RequestOutput] = {} - - self.model_executor = executor_class( - model_config=model_config, - cache_config=cache_config, - parallel_config=parallel_config, - scheduler_config=scheduler_config, - device_config=device_config, - lora_config=lora_config, - speculative_config=speculative_config, - load_config=load_config, - prompt_adapter_config=prompt_adapter_config, + # IPC serialization / deserialization + self.encoder = msgspec.msgpack.Encoder() + self.decoder = msgspec.msgpack.Decoder(EngineCoreOutputs) + + # IPC path setup + self.ctx = zmq.Context() # type: ignore[attr-defined] + from_core_ipc_path = get_open_zmq_ipc_path() + to_core_ipc_path = get_open_zmq_ipc_path() + core_ready_ipc_path = get_open_zmq_ipc_path() + + # Get output (EngineCoreOutput) from LLMEngineCore. + self.from_core = self.ctx.socket(zmq.constants.PULL) + self.from_core.connect(from_core_ipc_path) + + # Send input (new Requests) to LLMEngineCore. + self.to_core = self.ctx.socket(zmq.constants.PUSH) + self.to_core.bind(to_core_ipc_path) + + # TODO: some of these configs will be mutated by + # EngineCore (in a separate process), e.g. cache_config. + # It would be better if we could prune down what is needed + # for EngineCore and prevent having two sources of truth + # or perhaps made these immutable? + self.engine_core = LLMEngineCore( + input_path=to_core_ipc_path, + output_path=from_core_ipc_path, + core_ready_path=core_ready_ipc_path, + executor_class=executor_class, + model_config=self.model_config, + cache_config=self.cache_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config, + device_config=self.device_config, + load_config=self.load_config, + lora_config=self.lora_config, + speculative_config=self.speculative_config, observability_config=self.observability_config, + prompt_adapter_config=self.prompt_adapter_config, ) - assert self.model_config.task != "embedding" - self._initialize_kv_caches() + self.engine_core.start() + self._wait_for_engine_core(core_ready_ipc_path) - # Create the scheduler. - # NOTE: the cache_config here have been updated with the numbers of - # GPU and CPU blocks, which are profiled in the distributed executor. 
- self.scheduler = Scheduler(scheduler_config, cache_config, lora_config) + def __del__(self): + logger.debug("Shutting down LLMEngineCore.") - def _initialize_kv_caches(self) -> None: - num_gpu_blocks, _ = self.model_executor.determine_num_available_blocks( - ) + if hasattr(self, "ctx"): + self.ctx.destroy(linger=0) - if self.cache_config.num_gpu_blocks_override is not None: - num_gpu_blocks_override = self.cache_config.num_gpu_blocks_override - logger.info( - "Overriding num_gpu_blocks=%d with " - "num_gpu_blocks_override=%d", num_gpu_blocks, - num_gpu_blocks_override) - num_gpu_blocks = num_gpu_blocks_override - - self.cache_config.num_gpu_blocks = num_gpu_blocks - self.cache_config.num_cpu_blocks = 0 - self.model_executor.initialize_cache(num_gpu_blocks) + if hasattr(self, "engine_core"): + # TODO: do this more gracefully by sending a message? + # or at least using .terminate() + self.engine_core.kill() @classmethod def from_engine_args( @@ -211,6 +222,21 @@ def from_engine_args( ) return engine + def _wait_for_engine_core(self, ipc_path: str): + try: + ready_socket = self.ctx.socket(zmq.constants.PULL) + ready_socket.connect(ipc_path) + while ready_socket.poll(timeout=5000) == 0: + logger.debug("Waiting for LLMEngineCore to startup.") + if not self.engine_core.is_alive(): + raise RuntimeError("LLMEngineCore process failed to start") + + message = ready_socket.recv_string() + assert message == LLM_ENGINE_CORE_READY_STR + + finally: + ready_socket.close(linger=0) + def _init_tokenizer(self) -> BaseTokenizerGroup: return init_tokenizer_from_configs( model_config=self.model_config, @@ -229,35 +255,6 @@ def _verify_args(self) -> None: self.prompt_adapter_config.verify_with_model_config( self.model_config) - def _add_processed_request( - self, - request_id: str, - processed_inputs: Union[DecoderOnlyInputs, EncoderDecoderLLMInputs], - params: Union[SamplingParams, PoolingParams], - arrival_time: float, - lora_request: Optional[LoRARequest], - prompt_adapter_request: Optional[PromptAdapterRequest], - trace_headers: Optional[Mapping[str, str]] = None, - ) -> None: - assert prompt_adapter_request is None - assert trace_headers is None - self._validate_model_inputs(processed_inputs) - eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request) - - # TODO(woosuk): Support embedding mode. - assert isinstance(params, SamplingParams) - sampling_params = params.clone() - sampling_params.update_from_generation_config( - self.generation_config_fields, eos_token_id) - - # TODO(woosuk): Check max_logprobs - # TODO(woosuk): Support encoder-decoder models. - req = Request(request_id, processed_inputs, params, eos_token_id, - arrival_time) - self.requests[request_id] = req - self.num_lagged_steps[request_id] = 0 - self.scheduler.add_request(req) - def stop_remote_worker_execution_loop(self) -> None: raise NotImplementedError("TP not implemented yet.") @@ -272,6 +269,17 @@ def add_request( prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, ) -> None: + """ + Add new request to the LLMEngine, in 3 steps: + 1) Processing the raw inputs + 2) Adding the request to the Detokenizer (running in this process) + 3) Adding the request to the EngineCore (running in other process) + """ + + # TODO(woosuk): Support embedding mode. + # TODO(woosuk): Check max_logprobs + # TODO(woosuk): Support encoder-decoder models. 
+ if lora_request is not None and not self.lora_config: raise ValueError(f"Got lora_request {lora_request} but LoRA is " "not enabled!") @@ -279,180 +287,96 @@ def add_request( arrival_time = time.time() assert priority == 0, "vLLM V1 does not support priority at the moment." - preprocessed_inputs = self.input_preprocessor.preprocess( - prompt, - request_id=request_id, - lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request, - ) - processed_inputs = self.input_processor(preprocessed_inputs) + # 1) Process the inputs into the raw data needed for a request. + detokenizer_request, engine_core_request = self._make_requests( + request_id, prompt, params, arrival_time, lora_request, + prompt_adapter_request) - self._add_processed_request( - request_id=request_id, - processed_inputs=processed_inputs, - params=params, - arrival_time=arrival_time, - lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request, - trace_headers=trace_headers, - ) + # 2) Add the request to Detokenizer (this process). + self.detokenizer.add_request(detokenizer_request) + + # 3) Add the request to EngineCore (separate process). + self.to_core.send_multipart( + (self.encoder.encode(engine_core_request), ), + copy=False, + flags=zmq.NOBLOCK) def abort_request(self, request_id: Union[str, Iterable[str]]) -> None: - self.scheduler.finish_requests(request_id, - RequestStatus.FINISHED_ABORTED) - self._free_request(request_id) + # TODO: send to EngineCore + # TODO: send to Detokenizer + pass + + # NOTE: a significant drawback of this design is that we now have two + # trackers of running state (the Detokenizer and the Scheduler). + # Is there a better way to do this? + # Unfortunately we need to send state over IPC. + # Maybe we could get back the scheduler state with EngineCoreOutput? + # Such that state is explicitly in sync rather than implicitly? def get_num_unfinished_requests(self) -> int: - """Gets the number of unfinished requests.""" - return len(self.requests) + return self.detokenizer.get_num_unfinished_requests() def has_unfinished_requests(self) -> bool: - """Returns True if there are unfinished requests.""" - return len(self.requests) > 0 + return self.detokenizer.has_unfinished_requests() def step(self) -> List[RequestOutput]: - # NOTE(woosuk): This method may return an empty list when the - # detokenizer is still processing the outputs. This should not be - # considered as the end of the generation process. - # FIXME(woosuk): Currently, the step method is inefficient because it - # creates RequestOutput objects for all running requests, while they - # may not be needed unless the output is streamed to the client. - if self.scheduler.has_unfinished_requests(): - scheduler_output = self.scheduler.schedule() - output = self.model_executor.execute_model(scheduler_output) - sampled = self.scheduler.update_from_output( - scheduler_output, output) - self.send_to_detokenizer(sampled) - req_outputs = self.recv_from_detokenizer() - return req_outputs - - def send_to_detokenizer(self, sampled: List[Tuple[Request, int]]) -> None: - inputs = DetokenizerInputs( - req_ids=[], - prompt_token_ids=[], - new_token_ids=[], - skip_special_tokens=[], - spaces_between_special_tokens=[], - free_req_ids=[], # TODO(woosuk): Implement freeing. - ) - for req, num_tokens in sampled: - inputs.req_ids.append(req.request_id) - if len(req.output_token_ids) == num_tokens: - # The request is first detokenized.
- inputs.prompt_token_ids.append(req.prompt_token_ids) - else: - # The prompt token ids are already cached in the detokenizer. - inputs.prompt_token_ids.append([]) - inputs.new_token_ids.append(req.output_token_ids[-num_tokens:]) - inputs.skip_special_tokens.append( - req.sampling_params.skip_special_tokens) - inputs.spaces_between_special_tokens.append( - req.sampling_params.spaces_between_special_tokens) - - # Update the number of lagged steps. - self.num_lagged_steps[req.request_id] += 1 - self.detokenizer.send(inputs) - - def recv_from_detokenizer(self) -> List[RequestOutput]: - detokenizer_output = self.detokenizer.recv() - if detokenizer_output is None: - return [] - - req_outputs: List[RequestOutput] = [] - num_reqs = len(detokenizer_output.req_ids) - for i in range(num_reqs): - req_id = detokenizer_output.req_ids[i] - if req_id not in self.requests: - # The request has been aborted while the detokenizer was - # processing the outputs. - continue - - req = self.requests[req_id] - req.output_text += detokenizer_output.detokenized_texts[i] - - self.num_lagged_steps[req_id] -= 1 - finished = (self.num_lagged_steps[req_id] == 0 - and req.is_finished()) - req_output = self._make_request_output( - req, detokenizer_output.num_output_token_ids[i], - detokenizer_output.detokenized_texts[i], finished) - req_outputs.append(req_output) - - if finished: - self._free_request(req_id) - return req_outputs - - def terminate_detokenizer(self) -> None: - self.detokenizer.terminate() - - def _make_request_output( - self, - request: Request, - num_output_tokens: int, - new_output_text: str, - finished: bool, - ) -> RequestOutput: - req_output = self.request_outputs.get(request.request_id) - if req_output is None: - # TODO: Support `n` > 1. - completion_output = CompletionOutput( - index=0, - text="", - token_ids=[], - cumulative_logprob=None, - logprobs=None, # TODO - finish_reason=None, - stop_reason=None, - lora_request=None, - ) - req_output = RequestOutput( - request_id=request.request_id, - prompt=request.prompt, - prompt_token_ids=request.prompt_token_ids, - prompt_logprobs=None, # TODO - outputs=[completion_output], - finished=False, - metrics=None, - lora_request=None, - encoder_prompt=None, - encoder_prompt_token_ids=None, - ) - self.request_outputs[request.request_id] = req_output - - completion_output = req_output.outputs[0] - if request.sampling_params.output_kind == RequestOutputKind.CUMULATIVE: - completion_output.text += new_output_text - completion_output.token_ids = ( - request.output_token_ids[:num_output_tokens]) - elif request.sampling_params.output_kind == RequestOutputKind.DELTA: - completion_output.text = new_output_text - num_prev_tokens = len(completion_output.token_ids) - completion_output.token_ids = request.output_token_ids[ - num_prev_tokens:num_output_tokens] - elif (request.sampling_params.output_kind == - RequestOutputKind.FINAL_ONLY): - if finished: - completion_output.text = request.output_text - completion_output.token_ids = request.output_token_ids - else: - completion_output.text = "" - completion_output.token_ids = [] - - if finished: - completion_output.finish_reason = request.get_finished_reason() - completion_output.stop_reason = request.stop_reason - req_output.finished = finished - return req_output - - def _free_request(self, request_id: str) -> None: - self.requests.pop(request_id, None) - self.num_lagged_steps.pop(request_id, None) - self.request_outputs.pop(request_id, None) + # NOTE: This method returns an empty list if the EngineCore + # step is running. 
This is not the end of the generation process. + if self.from_core.poll(timeout=0) != 0: + frames = self.from_core.recv_multipart(copy=False) + engine_core_outputs = self.decoder.decode(frames[0].buffer).outputs + request_outputs = self.detokenizer.step(engine_core_outputs) + return request_outputs + + return [] def check_health(self) -> None: if self.tokenizer: self.tokenizer.check_health() - self.model_executor.check_health() + # self.model_executor.check_health() + # TODO: send health check to EngineCore. + + def _make_requests( + self, + request_id: str, + prompt: PromptType, + params: Union[SamplingParams, PoolingParams], + arrival_time: float, + lora_request: Optional[LoRARequest] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None + ) -> Tuple[DetokenizerRequest, EngineCoreRequest]: + + # Process inputs. + preprocessed_inputs = self.input_preprocessor.preprocess( + prompt, + request_id=request_id, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + ) + processed_inputs = self.input_processor(preprocessed_inputs) + self._validate_model_inputs(processed_inputs) + eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request) + + assert isinstance(params, SamplingParams) + sampling_params = params.clone() + sampling_params.update_from_generation_config( + self.generation_config_fields, eos_token_id) + + # Make Request for Detokenizer. + detokenizer_request = DetokenizerRequest( + request_id, processed_inputs.get("prompt"), + processed_inputs.get("prompt_token_ids"), + sampling_params.skip_special_tokens, + sampling_params.spaces_between_special_tokens, + sampling_params.output_kind) + + # Make Request for EngineCore. + engine_core_request = EngineCoreRequest( + request_id, processed_inputs.get("prompt"), + processed_inputs.get("prompt_token_ids"), sampling_params, + eos_token_id, arrival_time, lora_request) + + return detokenizer_request, engine_core_request def _validate_model_inputs(self, inputs: Union[DecoderOnlyInputs, EncoderDecoderLLMInputs]): diff --git a/vllm/v1/engine/llm_engine_core.py b/vllm/v1/engine/llm_engine_core.py new file mode 100644 index 0000000000000..9007f21129ecf --- /dev/null +++ b/vllm/v1/engine/llm_engine_core.py @@ -0,0 +1,186 @@ +import multiprocessing +from typing import List, Optional, Type + +import msgspec +import zmq + +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, + ModelConfig, ObservabilityConfig, ParallelConfig, + PromptAdapterConfig, SchedulerConfig, + SpeculativeConfig) +from vllm.logger import init_logger +from vllm.v1.core.scheduler import Scheduler +from vllm.v1.engine import (LLM_ENGINE_CORE_READY_STR, EngineCoreOutput, + EngineCoreOutputs, EngineCoreRequest) +from vllm.v1.executor.gpu_executor import GPUExecutor +from vllm.v1.request import Request + +logger = init_logger(__name__) + +POLLING_TIMEOUT_MS = 10000 + + +# TODO: better name? LLMEngineProc? 
+class LLMEngineCore(multiprocessing.Process): + + def __init__( + self, + input_path: str, + output_path: str, + core_ready_path: str, + executor_class: Type[GPUExecutor], + model_config: ModelConfig, + cache_config: CacheConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + load_config: LoadConfig, + lora_config: Optional[LoRAConfig], + speculative_config: Optional[SpeculativeConfig], + observability_config: Optional[ObservabilityConfig], + prompt_adapter_config: Optional[PromptAdapterConfig], + ): + super().__init__() + self.input_path = input_path + self.output_path = output_path + self.core_ready_path = core_ready_path + self.executor_class = executor_class + self.model_config = model_config + self.cache_config = cache_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + self.device_config = device_config + self.speculative_config = speculative_config + self.load_config = load_config + self.lora_config = lora_config + self.observability_config = observability_config + self.prompt_adapter_config = prompt_adapter_config + + def run(self): + # Initialize these objects after the process is forked. + self.msgpack_encoder = msgspec.msgpack.Encoder() + self.msgpack_decoder = msgspec.msgpack.Decoder(EngineCoreRequest) + + self.ctx = zmq.Context() # type: ignore[attr-defined] + + # Get EngineCoreRequests from the LLMEngine. + self.input_socket = self.ctx.socket(zmq.constants.PULL) + self.input_socket.connect(self.input_path) + + # Send EngineCoreOutput to the LLMEngine. + self.output_socket = self.ctx.socket(zmq.constants.PUSH) + self.output_socket.bind(self.output_path) + + # Setup Model. + self.model_executor = self.executor_class( + model_config=self.model_config, + cache_config=self.cache_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config, + device_config=self.device_config, + lora_config=self.lora_config, + speculative_config=self.speculative_config, + load_config=self.load_config, + prompt_adapter_config=self.prompt_adapter_config, + observability_config=self.observability_config, + ) + + # Setup KV Caches. + # NOTE: the cache_config is updated with the numbers of GPU and CPU + # blocks, which are profiled in the distributed executor. + self._initialize_kv_caches() + + # Setup Scheduler. + self.scheduler = Scheduler(self.scheduler_config, self.cache_config, + self.lora_config) + + # TODO: add heartbeat thread. + + # Send LLM_ENGINE_CORE_READY_STR. + try: + ready_socket = self.ctx.socket(zmq.constants.PUSH) + ready_socket.bind(self.core_ready_path) + ready_socket.send_string(LLM_ENGINE_CORE_READY_STR) + finally: + ready_socket.close(linger=0) + + # Kick off the busy loop. + self.core_loop() + + def _initialize_kv_caches(self) -> None: + num_gpu_blocks, _ = self.model_executor.determine_num_available_blocks( + ) + + if self.cache_config.num_gpu_blocks_override is not None: + num_gpu_blocks_override = self.cache_config.num_gpu_blocks_override + logger.info( + "Overriding num_gpu_blocks=%d with " + "num_gpu_blocks_override=%d", num_gpu_blocks, + num_gpu_blocks_override) + num_gpu_blocks = num_gpu_blocks_override + + self.cache_config.num_gpu_blocks = num_gpu_blocks + self.cache_config.num_cpu_blocks = 0 + self.model_executor.initialize_cache(num_gpu_blocks) + + def core_loop(self): + """Core busy loop of the LLMEngineCore.""" + + while True: + # Poll the input socket until there is work to do.
+ if not self.scheduler.has_unfinished_requests(): + while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0: + logger.debug("Waiting for new requests.") + + # Handle new input from the socket. + self._handle_new_input() + + # Forward pass. + outputs = self._step() + + # Stream outputs to the LLMEngine. + self._send_outputs(outputs) + + def _step(self) -> Optional[List[EngineCoreOutput]]: + """Schedule, execute, and make output.""" + + if not self.scheduler.has_unfinished_requests(): + return None + + scheduler_output = self.scheduler.schedule() + output = self.model_executor.execute_model(scheduler_output) + engine_core_outputs = self.scheduler.update_from_output( + scheduler_output, output) + return engine_core_outputs + + def _handle_new_input(self): + """Handle new input from the LLMEngine.""" + try: + while self.input_socket.poll(timeout=0) != 0: + frames = self.input_socket.recv_multipart(copy=False) + engine_core_request = self.msgpack_decoder.decode( + frames[0].buffer) + request = Request.from_engine_core_request(engine_core_request) + self.scheduler.add_request(request) + + # TODO: handle abort via another socket + # TODO: handle logits processors via cloudpickle + # TODO: handle profiling + + except Exception as e: + # TODO: handle gracefully + raise e + + def _send_outputs( + self, + engine_core_outputs: Optional[List[EngineCoreOutput]]) -> None: + """Serialize and send output to the LLMEngine.""" + + if engine_core_outputs is None: + return + + outputs = EngineCoreOutputs(outputs=engine_core_outputs) + outputs_serialized = self.msgpack_encoder.encode(outputs) + self.output_socket.send_multipart((outputs_serialized, ), + copy=False, + flags=zmq.NOBLOCK) diff --git a/vllm/v1/request.py b/vllm/v1/request.py index be7d4d165d280..d1b2431f1fbd5 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -1,9 +1,11 @@ import enum from typing import TYPE_CHECKING, List, Optional, Union +from vllm.inputs.data import DecoderOnlyInputs from vllm.lora.request import LoRARequest from vllm.sampling_params import SamplingParams from vllm.sequence import RequestMetrics +from vllm.v1.engine import EngineCoreRequest if TYPE_CHECKING: from vllm.inputs import DecoderOnlyInputs @@ -41,9 +43,21 @@ def __init__( self.prompt_token_ids = inputs["prompt_token_ids"] self.num_prompt_tokens = len(self.prompt_token_ids) self.output_token_ids: List[int] = [] - self.output_text = "" self.num_computed_tokens = 0 + @classmethod + def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request": + + return cls( + request_id=request.request_id, + inputs=DecoderOnlyInputs(prompt_token_ids=request.prompt_token_ids, + prompt=request.prompt), + sampling_params=request.sampling_params, + eos_token_id=request.eos_token_id, + arrival_time=request.arrival_time, + lora_request=request.lora_request, + ) + @property def num_tokens(self) -> int: return self.num_prompt_tokens + len(self.output_token_ids) diff --git a/vllm/v1/tokenizer/__init__.py b/vllm/v1/tokenizer/__init__.py deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/vllm/v1/tokenizer/detokenizer.py b/vllm/v1/tokenizer/detokenizer.py deleted file mode 100644 index 4bbcf4717981e..0000000000000 --- a/vllm/v1/tokenizer/detokenizer.py +++ /dev/null @@ -1,215 +0,0 @@ -import multiprocessing -from dataclasses import dataclass -from typing import Dict, List, Optional -import msgspec -import zmq -from msgspec import msgpack - -from vllm.transformers_utils.detokenizer_utils import ( - convert_prompt_ids_to_tokens,
detokenize_incrementally) -from vllm.transformers_utils.tokenizer import get_tokenizer -from vllm.utils import get_open_port - - -class DetokenizerInputs(msgspec.Struct): - - # [num_reqs] - req_ids: List[str] - # A request's prompt token ids is sent to the detokenizer only when - # the request is first detokenized. Otherwise, an empty list is sent. - prompt_token_ids: List[List[int]] - new_token_ids: List[List[int]] - skip_special_tokens: List[bool] - spaces_between_special_tokens: List[bool] - - # [num_free_reqs] - free_req_ids: List[str] - - -class DetokenizerOutputs(msgspec.Struct): - - # [num_reqs] - req_ids: List[str] - detokenized_texts: List[str] - # NOTE(woosuk): The number of the output token ids of each request - # at the time of detokenization. The detokenizer returns this to the engine - # because the request state (including the output token ids) is - # asynchronously updated in the engine, while RequestOutput requires the - # output token ids to be consistent with the detokenized text. - num_output_token_ids: List[int] - - -class Detokenizer: - - def __init__(self, tokenizer_name: str): - # FIXME(woosuk): Currently, the detokenizer is just a hacky prototype. - # For example, it does not terminate properly. We need to improve this. - self.push_port = get_open_port() - self.pull_port = get_open_port() - self.detokenizer = DetokenizerProc(tokenizer_name, self.push_port, - self.pull_port) - self.detokenizer.start() - - self.zmq_context = zmq.Context() - self.push_socket = self.zmq_context.socket(zmq.PUSH) - self.push_socket.connect(f"tcp://localhost:{self.push_port}") - self.pull_socket = self.zmq_context.socket(zmq.PULL) - self.pull_socket.connect(f"tcp://localhost:{self.pull_port}") - self.poller = zmq.Poller() - self.poller.register(self.pull_socket, zmq.POLLIN) - self.msgpack_encoder = msgpack.Encoder() - self.msgpack_decoder = msgpack.Decoder(DetokenizerOutputs) - - def send(self, inputs: DetokenizerInputs) -> None: - self.push_socket.send(self.msgpack_encoder.encode(inputs), - flags=zmq.NOBLOCK) - - def recv(self) -> Optional[DetokenizerOutputs]: - socks = dict(self.poller.poll(timeout=0)) - if self.pull_socket in socks and socks[self.pull_socket] == zmq.POLLIN: - msg = self.pull_socket.recv() - return self.msgpack_decoder.decode(msg) - return None - - def terminate(self) -> None: - self.push_socket.send(b"", flags=zmq.NOBLOCK) - self.detokenizer.join() - - -class DetokenizerProc(multiprocessing.Process): - - def __init__( - self, - tokenizer_name: str, - pull_port: int, - push_port: int, - ): - super().__init__() - self.tokenizer_name = tokenizer_name - # NOTE: The pull_port of the detokenizer should be the same as the - # push_port of the engine. Vice versa. - self.pull_port = pull_port - self.push_port = push_port - - def run(self): - # Initialize these objects after the process is forked since they are - # not picklable. - self.msgpack_encoder = msgpack.Encoder() - self.msgpack_decoder = msgpack.Decoder(DetokenizerInputs) - self.tokenizer = get_tokenizer(self.tokenizer_name) - # req_id -> RequestState - self.request_states: Dict[str, RequestState] = {} - - self.zmq_context = zmq.Context() - self.pull_socket = self.zmq_context.socket(zmq.PULL) - self.pull_socket.bind(f"tcp://*:{self.pull_port}") - self.push_socket = self.zmq_context.socket(zmq.PUSH) - self.push_socket.bind(f"tcp://*:{self.push_port}") - - while True: - message = self.pull_socket.recv() - if message == b"": - # Terminate signal. 
- break - inputs = self.msgpack_decoder.decode(message) - - for req_id in inputs.free_req_ids: - self.free(req_id) - - detokenized_texts: List[str] = [] - num_output_token_ids: List[int] = [] - num_reqs = len(inputs.req_ids) - for i in range(num_reqs): - req_id = inputs.req_ids[i] - if req_id not in self.request_states: - self.add_request( - request_id=req_id, - prompt_token_ids=inputs.prompt_token_ids[i], - skip_special_tokens=inputs.skip_special_tokens[i], - spaces_between_special_tokens=inputs. - spaces_between_special_tokens[i], - ) - new_str = self.detokenize(req_id, inputs.new_token_ids[i]) - detokenized_texts.append(new_str) - req_state = self.request_states[req_id] - num_output_token_ids.append( - len(req_state.token_ids) - req_state.num_prompt_tokens) - - detokenized = DetokenizerOutputs( - req_ids=inputs.req_ids, - detokenized_texts=detokenized_texts, - num_output_token_ids=num_output_token_ids, - ) - self.push_socket.send(self.msgpack_encoder.encode(detokenized), - flags=zmq.NOBLOCK) - - def add_request( - self, - request_id: str, - prompt_token_ids: List[int], - skip_special_tokens: bool, - spaces_between_special_tokens: bool, - ) -> None: - tokens, prefix_offset, read_offset = convert_prompt_ids_to_tokens( - tokenizer=self.tokenizer, - prompt_ids=prompt_token_ids, - skip_special_tokens=skip_special_tokens, - ) - self.request_states[request_id] = RequestState( - req_id=request_id, - token_ids=prompt_token_ids, - tokens=tokens, - num_prompt_tokens=len(prompt_token_ids), - prefix_offset=prefix_offset, - read_offset=read_offset, - skip_special_tokens=skip_special_tokens, - spaces_between_special_tokens=spaces_between_special_tokens, - ) - - def free(self, request_id: str) -> None: - del self.request_states[request_id] - - def detokenize(self, request_id: str, new_token_ids: List[int]) -> str: - # TODO(woosuk): This method becomes very inefficient when the number of - # new_token_ids is more than 1. We need to optimize this. - req_state = self.request_states[request_id] - decoded_text = "" - for new_token_id in new_token_ids: - req_state.token_ids.append(new_token_id) - (new_tokens, new_decoded_token_text, prefix_offset, - read_offset) = detokenize_incrementally( - tokenizer=self.tokenizer, - all_input_ids=req_state.token_ids, - prev_tokens=req_state.tokens, - prefix_offset=req_state.prefix_offset, - read_offset=req_state.read_offset, - skip_special_tokens=req_state.skip_special_tokens, - spaces_between_special_tokens=req_state. - spaces_between_special_tokens, - ) - - req_state.tokens.extend(new_tokens) - req_state.prefix_offset = prefix_offset - req_state.read_offset = read_offset - req_state.output_text += new_decoded_token_text - decoded_text += new_decoded_token_text - return decoded_text - - -@dataclass -class RequestState: - - req_id: str - - token_ids: List[int] - tokens: List[str] - num_prompt_tokens: int - - prefix_offset: int - read_offset: int - - skip_special_tokens: bool - spaces_between_special_tokens: bool - - output_text: str = ""
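Reviewer note: the new llm_engine.py / llm_engine_core.py pair talks over ZeroMQ PUSH/PULL sockets carrying msgspec msgpack-serialized structs. The sketch below is a minimal, self-contained illustration of that pattern only, not the PR's code: MiniRequest, MiniOutput, MiniOutputs and the ipc:// paths are made-up stand-ins for EngineCoreRequest / EngineCoreOutput(s), and both "sides" run in a single process here purely for demonstration, whereas the PR runs LLMEngineCore in a separate process.

from typing import List, Optional

import msgspec
import zmq


class MiniRequest(msgspec.Struct):
    # Stand-in for EngineCoreRequest (illustration only).
    request_id: str
    prompt_token_ids: List[int]


class MiniOutput(msgspec.Struct):
    # Stand-in for EngineCoreOutput (illustration only).
    request_id: str
    new_token_ids: List[int]
    finished: bool
    finish_reason: Optional[str] = None


class MiniOutputs(msgspec.Struct):
    # Stand-in for EngineCoreOutputs: a batch of per-request outputs.
    outputs: List[MiniOutput]


def main() -> None:
    ctx = zmq.Context()
    to_core_path = "ipc:///tmp/mini_to_core"      # hypothetical paths
    from_core_path = "ipc:///tmp/mini_from_core"

    # "LLMEngine" side: PUSH requests out, PULL outputs back (the engine
    # binds the request path and connects the output path, as in the diff).
    to_core = ctx.socket(zmq.PUSH)
    to_core.bind(to_core_path)
    from_core = ctx.socket(zmq.PULL)
    from_core.connect(from_core_path)

    # "LLMEngineCore" side: PULL requests in, PUSH outputs back.
    core_in = ctx.socket(zmq.PULL)
    core_in.connect(to_core_path)
    core_out = ctx.socket(zmq.PUSH)
    core_out.bind(from_core_path)

    encoder = msgspec.msgpack.Encoder()
    req_decoder = msgspec.msgpack.Decoder(MiniRequest)
    out_decoder = msgspec.msgpack.Decoder(MiniOutputs)

    # Engine -> core: one serialized request.
    to_core.send(encoder.encode(MiniRequest("req-0", [1, 2, 3])))
    request = req_decoder.decode(core_in.recv())

    # Core -> engine: one serialized batch of outputs for that request.
    batch = MiniOutputs(
        outputs=[MiniOutput(request.request_id, [42], True, "length")])
    core_out.send(encoder.encode(batch))
    print(out_decoder.decode(from_core.recv()))

    for sock in (to_core, from_core, core_in, core_out):
        sock.close(linger=0)
    ctx.term()


if __name__ == "__main__":
    main()

Using typed msgspec.msgpack.Decoder(...) instances on each side, as the PR does, presumably buys schema-checked, pickle-free messages, with the trade-off that the struct definitions must stay importable by both processes.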