From f9a0f75542b5dacb945fba0412f0f6345fcb24c4 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Wed, 6 Nov 2024 12:46:33 -0800 Subject: [PATCH] Various updates --- benchmarks/backend_request_func.py | 11 +++- vllm/config.py | 44 +++++++++++++ vllm/engine/output_processor/stop_checker.py | 6 +- vllm/outputs.py | 2 +- vllm/v1/engine/__init__.py | 12 +++- vllm/v1/engine/async_llm.py | 66 +++++++++++--------- vllm/v1/engine/core.py | 60 ++++-------------- vllm/v1/engine/detokenizer.py | 21 ++++--- vllm/v1/request.py | 3 +- 9 files changed, 128 insertions(+), 97 deletions(-) diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py index a42e70170ba28..447a7ca429761 100644 --- a/benchmarks/backend_request_func.py +++ b/benchmarks/backend_request_func.py @@ -230,6 +230,8 @@ async def async_request_openai_completions( ("completions", "profile") ), "OpenAI Completions API URL must end with 'completions' or 'profile'." + stream = True + async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: payload = { "model": request_func_input.model, @@ -238,7 +240,7 @@ async def async_request_openai_completions( "best_of": request_func_input.best_of, "max_tokens": request_func_input.output_len, "logprobs": request_func_input.logprobs, - "stream": True, + "stream": stream, "ignore_eos": request_func_input.ignore_eos, } headers = { @@ -263,9 +265,10 @@ async def async_request_openai_completions( chunk = chunk_bytes.decode("utf-8").removeprefix( "data: ") - if chunk == "[DONE]": + stream_is_done = stream and chunk == "[DONE]" + if not stream or stream_is_done: latency = time.perf_counter() - st - else: + if not stream_is_done: data = json.loads(chunk) # NOTE: Some completion API might have a last @@ -379,10 +382,12 @@ async def async_request_openai_chat_completions( else: output.error = response.reason or "" output.success = False + print("Error reason", response.reason) except Exception: output.success = False exc_info = sys.exc_info() output.error = "".join(traceback.format_exception(*exc_info)) + traceback.print_exc() if pbar: pbar.update(1) diff --git a/vllm/config.py b/vllm/config.py index 91bbbfec4b7b3..5b8f3a863e5fb 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2038,3 +2038,47 @@ def __post_init__(self): self.model_config is not None and self.load_config is not None: self.quant_config = VllmConfig._get_quantization_config( self.model_config, self.load_config) + + def __str__(self): + return ("model=%r, speculative_config=%r, tokenizer=%r, " + "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, " + "override_neuron_config=%s, " + "rope_scaling=%r, rope_theta=%r, tokenizer_revision=%s, " + "trust_remote_code=%s, dtype=%s, max_seq_len=%d, " + "download_dir=%r, load_format=%s, tensor_parallel_size=%d, " + "pipeline_parallel_size=%d, " + "disable_custom_all_reduce=%s, quantization=%s, " + "enforce_eager=%s, kv_cache_dtype=%s, " + "quantization_param_path=%s, device_config=%s, " + "decoding_config=%r, observability_config=%r, " + "seed=%d, served_model_name=%s, " + "num_scheduler_steps=%d, enable_prefix_caching=%s, " + "use_async_output_proc=%s, mm_processor_kwargs=%s") % \ + (self.model_config.model, self.speculative_config, + self.model_config.tokenizer, + self.model_config.skip_tokenizer_init, + self.model_config.tokenizer_mode, + self.model_config.revision, + self.model_config.override_neuron_config, + self.model_config.rope_scaling, + self.model_config.rope_theta, + self.model_config.tokenizer_revision, + self.model_config.trust_remote_code, + 
self.model_config.dtype, + self.model_config.max_model_len, + self.load_config.download_dir, + self.load_config.load_format, + self.parallel_config.tensor_parallel_size, + self.parallel_config.pipeline_parallel_size, + self.parallel_config.disable_custom_all_reduce, + self.model_config.quantization, + self.model_config.enforce_eager, + self.cache_config.cache_dtype, + self.model_config.quantization_param_path, + self.device_config.device, self.decoding_config, + self.observability_config, self.model_config.seed, + self.model_config.served_model_name, + self.scheduler_config.num_scheduler_steps, + self.cache_config.enable_prefix_caching, + self.model_config.use_async_output_proc, + self.model_config.mm_processor_kwargs) \ No newline at end of file diff --git a/vllm/engine/output_processor/stop_checker.py b/vllm/engine/output_processor/stop_checker.py index e25649b2b0b9f..4b701f81504bb 100644 --- a/vllm/engine/output_processor/stop_checker.py +++ b/vllm/engine/output_processor/stop_checker.py @@ -98,7 +98,11 @@ def check_stop_strings( """Check if any stop strings are matched and truncate sequence output text accordingly. - Returns the stop string if matched or else None. + Returns tuple (stop_string, offset) if matched or else None. + + Where stop_string is the matched stop string and offset is the + length to which output_text should be truncated, or -1 for no + truncation. """ if not new_char_count or not stop: return None diff --git a/vllm/outputs.py b/vllm/outputs.py index 433b0631d524b..abfdb7d328126 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -123,7 +123,7 @@ def new( token_ids: List[int], finished: bool = False, ) -> "RequestOutput": - """Initialize a new "empty" RequestOutput object.""" + """Initialize a new RequestOutput object.""" # TODO: Support `n` > 1. completion_output = CompletionOutput( diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index b60d1da195e7f..80794a209c7f9 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -25,7 +25,7 @@ class DetokenizerRequest: include_stop_str_in_output: bool -class EngineCoreRequest(msgspec.Struct): +class EngineCoreRequest(msgspec.Struct, omit_defaults=True): # NOTE: prompt and prompt_token_ids should be DecoderOnlyInput, # but this object is currently not playing well with msgspec @@ -42,7 +42,10 @@ class EngineCoreRequest(msgspec.Struct): lora_request: Optional[LoRARequest] -class EngineCoreOutput(msgspec.Struct, array_like=True): +class EngineCoreOutput(msgspec.Struct, + array_like=True, + omit_defaults=True, + gc=False): request_id: str new_token_ids: List[int] @@ -51,7 +54,10 @@ class EngineCoreOutput(msgspec.Struct, array_like=True): stop_reason: Union[int, str, None] = None -class EngineCoreOutputs(msgspec.Struct, array_like=True): +class EngineCoreOutputs(msgspec.Struct, + array_like=True, + omit_defaults=True, + gc=False): #NOTE(Nick): We could consider ways to make this more compact, # e.g. columnwise layout and using an int enum for finish/stop reason diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 4c2529b8f34da..1d2129ecd84e2 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -55,6 +55,8 @@ def __init__( # Map (request_id -> Stream) self.request_streams: Dict[str, AsyncStream] = {} + # List of cancelled request ids to be aborted. 
+ self.client_aborted_requests: List[str] = [] # Processor (converts Inputs --> EngineCoreRequests) self.processor = Processor(vllm_config.model_config, @@ -76,28 +78,26 @@ def __init__( # TODO: add background loop shielding # TODO: add AsyncEngineDeadError - self.is_output_handler_running = False + self.output_handler = None @classmethod def from_engine_args( cls, engine_args: AsyncEngineArgs, - engine_config: Optional[VllmConfig] = None, + vllm_config: Optional[VllmConfig] = None, start_engine_loop: bool = True, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, ) -> "AsyncLLMEngine": - """Creates an AsyncLLMEngine from the EngineArgs.""" + """Creates an AsyncLLM from the EngineArgs.""" # Create the engine configs. - if engine_config is None: + if vllm_config is None: vllm_config = engine_args.create_engine_config() - else: - vllm_config = engine_config executor_class = cls._get_executor_cls(vllm_config) - # Create the AsyncLLMEngine. + # Create the AsyncLLM. return cls( vllm_config=vllm_config, executor_class=executor_class, @@ -112,8 +112,8 @@ def shutdown(self): """Shutdown the EngineCore.""" self.engine_core.shutdown() - if hasattr(self, "output_handler"): - self.output_handler.cancel() + if handler := getattr(self, "output_handler", None): + handler.cancel() @classmethod def _get_executor_cls(cls, vllm_config: VllmConfig): @@ -123,15 +123,9 @@ def _add_request_to_streams(self, request_id: str) -> AsyncStream: if request_id in self.request_streams: raise ValueError(f"Request id {request_id} already running.") - # TODO: handle abort. - # IDEA(Nick): we could batch up aborts rather than sending - # them individually, so that we send at most one batch of - # aborts per step (added to any that we're doing due to - # stop string matches for that step) - def _abort(): - pass - - stream = AsyncStream(request_id, _abort) + # Avoid streams having circular ref to parent AsyncLLM object. + aborted_reqs = self.client_aborted_requests + stream = AsyncStream(request_id, aborted_reqs.append) self.request_streams[request_id] = stream return stream @@ -140,20 +134,33 @@ def _send_to_streams(self, request_outputs: List[RequestOutput]): for request_output in request_outputs: request_id = request_output.request_id - assert request_id in self.request_streams - self.request_streams[request_id].put(request_output) + stream = self.request_streams.get(request_id) + if stream is not None: + finished = request_output.finished + stream.put(request_output) + if finished: + self._finish_stream(request_id) - if request_output.finished: - self.request_streams[request_id].finish() - self.request_streams.pop(request_id) + def _finish_stream(self, request_id: str): + stream = self.request_streams.pop(request_id) + if stream is not None: + stream.finish() - async def abort_requests(self, request_ids: List[str]) -> None: + async def _abort_requests(self, request_ids: List[str]) -> None: """Remove request_ids from EngineCore and Detokenizer.""" - if len(request_ids) > 0: + # Include any client cancellations. 
+ client_aborted_reqs = self.client_aborted_requests + if client_aborted_reqs: + self.detokenizer.abort_requests(client_aborted_reqs) + for request_id in client_aborted_reqs: + self._finish_stream(request_id) + request_ids.extend(client_aborted_reqs) + client_aborted_reqs.clear() + + if request_ids: await self.engine_core.abort_requests_async(request_ids) - self.detokenizer.abort_requests(request_ids) async def add_request( self, @@ -205,10 +212,9 @@ async def generate( # We start the output_handler on the first call to generate() so that # we can call __init__ before the event loop starts, which enables us # to handle startup failure gracefully in the OpenAI server. - if not self.is_output_handler_running: + if self.output_handler is None: self.output_handler = asyncio.create_task( self._run_output_handler()) - self.is_output_handler_running = True async for output in await self.add_request( request_id, @@ -241,7 +247,7 @@ async def _run_output_handler(self): self._send_to_streams(request_outputs) # Abort any requests that finished due to stop strings. - await self.abort_requests(reqs_to_abort) + await self._abort_requests(reqs_to_abort) except BaseException as e: logger.error(e) @@ -288,7 +294,7 @@ async def do_log_stats( logger.debug("Called do_log_stats.") async def check_health(self) -> None: - logger.debug("Called do_log_stats.") + logger.debug("Called check_health.") async def start_profile(self) -> None: raise ValueError("Not supported on V1 yet.") diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 2f8767bad0207..86319bede99ad 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -44,50 +44,8 @@ def __init__( assert vllm_config.model_config.task != "embedding" - logger.info( - "Initializing an LLM engine (v%s) with config: " - "model=%r, speculative_config=%r, tokenizer=%r, " - "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, " - "override_neuron_config=%s, " - "rope_scaling=%r, rope_theta=%r, tokenizer_revision=%s, " - "trust_remote_code=%s, dtype=%s, max_seq_len=%d, " - "download_dir=%r, load_format=%s, tensor_parallel_size=%d, " - "pipeline_parallel_size=%d, " - "disable_custom_all_reduce=%s, quantization=%s, " - "enforce_eager=%s, kv_cache_dtype=%s, " - "quantization_param_path=%s, device_config=%s, " - "decoding_config=%r, observability_config=%r, " - "seed=%d, served_model_name=%s, " - "num_scheduler_steps=%d, enable_prefix_caching=%s, " - "use_async_output_proc=%s, mm_processor_kwargs=%s)", VLLM_VERSION, - vllm_config.model_config.model, vllm_config.speculative_config, - vllm_config.model_config.tokenizer, - vllm_config.model_config.skip_tokenizer_init, - vllm_config.model_config.tokenizer_mode, - vllm_config.model_config.revision, - vllm_config.model_config.override_neuron_config, - vllm_config.model_config.rope_scaling, - vllm_config.model_config.rope_theta, - vllm_config.model_config.tokenizer_revision, - vllm_config.model_config.trust_remote_code, - vllm_config.model_config.dtype, - vllm_config.model_config.max_model_len, - vllm_config.load_config.download_dir, - vllm_config.load_config.load_format, - vllm_config.parallel_config.tensor_parallel_size, - vllm_config.parallel_config.pipeline_parallel_size, - vllm_config.parallel_config.disable_custom_all_reduce, - vllm_config.model_config.quantization, - vllm_config.model_config.enforce_eager, - vllm_config.cache_config.cache_dtype, - vllm_config.model_config.quantization_param_path, - vllm_config.device_config.device, vllm_config.decoding_config, - vllm_config.observability_config, 
vllm_config.model_config.seed, - vllm_config.model_config.served_model_name, - vllm_config.scheduler_config.num_scheduler_steps, - vllm_config.cache_config.enable_prefix_caching, - vllm_config.model_config.use_async_output_proc, - vllm_config.model_config.mm_processor_kwargs) + logger.info("Initializing an LLM engine (v%s) with config: %s", + VLLM_VERSION, vllm_config) # Setup Model. self.model_executor = executor_class(vllm_config) @@ -129,6 +87,9 @@ def add_request(self, request: EngineCoreRequest): def abort_requests(self, request_ids: List[str]): """Abort requests from the scheduler.""" + # TODO: The scheduler doesn't really need to know the + # specific finish reason, TBD whether we propagate that + # (i.e. client-aborted vs stop criteria met). self.scheduler.finish_requests(request_ids, RequestStatus.FINISHED_ABORTED) @@ -166,7 +127,9 @@ def __init__( self.should_shutdown = should_shutdown # Background Threads and Queues for IO. These enable us to - # overlap ZMQ socket IO with GPU since they release the GIL. + # overlap ZMQ socket IO with GPU since they release the GIL, + # and to overlap some serialization/deserialization with the + # model forward pass. # Threads handle Socket <-> Queues and core_busy_loop uses Queue. self.input_queue = queue.Queue() self.output_queue = queue.Queue() @@ -271,7 +234,6 @@ def make_engine_core_process( def run_engine_core(*args, **kwargs): """Launch EngineCore busy loop in background process.""" - engine_core = None try: engine_core = EngineCoreProc(*args, **kwargs) engine_core.run_busy_loop() @@ -352,12 +314,14 @@ def process_output_socket(self, output_path: str): # Msgpack serialization encoding.. encoder = msgpack.Encoder() + # Reuse send buffer + buffer = bytearray() with self.make_socket(output_path, zmq.constants.PUSH) as socket: while True: engine_core_outputs = self.output_queue.get() outputs = EngineCoreOutputs(outputs=engine_core_outputs) - outputs_serialized = encoder.encode(outputs) - socket.send_multipart((outputs_serialized, ), + encoder.encode_into(outputs, buffer) + socket.send_multipart((buffer, ), copy=False, flags=zmq.NOBLOCK) diff --git a/vllm/v1/engine/detokenizer.py b/vllm/v1/engine/detokenizer.py index 75b4044854b7f..1dbf8e75ec478 100644 --- a/vllm/v1/engine/detokenizer.py +++ b/vllm/v1/engine/detokenizer.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple +from typing import Dict, Iterable, List, Optional, Tuple from vllm.engine.output_processor.stop_checker import StopChecker from vllm.logger import init_logger @@ -43,7 +43,7 @@ class IncrementalDetokenizer: tokenizer: AnyTokenizer # Accounting for stop string buffering - buffer_length: int + stop_buffer_length: int _last_output_text_offset: int = 0 @property @@ -67,8 +67,10 @@ def from_new_request( stops = request.stop # Number of chars to hold back when stop strings are to be excluded # from streamed output. 
- buffer_length = 0 if not stops or request.include_stop_str_in_output \ - else max(len(s) for s in stops) - 1 + if stops and not request.include_stop_str_in_output: + stop_buffer_length = max(len(s) for s in stops) - 1 + else: + stop_buffer_length = 0 return cls( output_text="", @@ -88,7 +90,7 @@ def from_new_request( prompt=request.prompt, prompt_token_ids=request.prompt_token_ids, tokenizer=tokenizer, - buffer_length=buffer_length, + stop_buffer_length=stop_buffer_length, ) def add_tokens( @@ -128,7 +130,7 @@ def add_tokens( decoded_text += new_decoded_token_text - # 2) Evaluate stop criteria + # 2) Evaluate stop criteria. if self.stop: stop = StopChecker.check_stop_strings( output_text=self.output_text, @@ -176,7 +178,7 @@ def _get_next_output_text(self, finished: bool, delta: bool) -> str: this method is returned""" # We return the full output text if the sequence is finished. - buffer_length = 0 if finished else self.buffer_length + buffer_length = 0 if finished else self.stop_buffer_length if not delta: return self.output_text[:-buffer_length] if buffer_length else ( self.output_text) @@ -209,13 +211,12 @@ def has_unfinished_requests(self) -> bool: def abort_requests( self, - request_ids: List[str], + request_ids: Iterable[str], ) -> None: """Remove the request_ids from the Detokenizer.""" for request_id in request_ids: - if request_id in self.request_states: - self.request_states.pop(request_id) + self.request_states.pop(request_id, None) def add_request( self, diff --git a/vllm/v1/request.py b/vllm/v1/request.py index d1b2431f1fbd5..3029d52a1830f 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -50,7 +50,8 @@ def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request": return cls( request_id=request.request_id, - inputs=DecoderOnlyInputs(prompt_token_ids=request.prompt_token_ids, + inputs=DecoderOnlyInputs(type="token", + prompt_token_ids=request.prompt_token_ids, prompt=request.prompt), sampling_params=request.sampling_params, eos_token_id=request.eos_token_id,
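
Note on the async_llm.py hunks: each AsyncStream is now constructed with list.append on a shared client_aborted_requests list as its cancellation callback, which avoids a circular reference from the stream back to the parent AsyncLLM and lets _abort_requests flush all client cancellations to the engine core in one call per output-handler step. A minimal sketch of that pattern follows; the Stream class and names below are illustrative stand-ins, not vLLM's AsyncStream.

    from typing import Callable, List


    class Stream:
        """Stand-in for AsyncStream: holds only a request id and a cancel callback."""

        def __init__(self, request_id: str, cancel: Callable[[str], None]):
            self.request_id = request_id
            self._cancel = cancel

        def cancel(self) -> None:
            # The callback is just list.append, so no reference back to the engine.
            self._cancel(self.request_id)


    client_aborted: List[str] = []
    streams = {rid: Stream(rid, client_aborted.append) for rid in ("a", "b", "c")}

    streams["b"].cancel()  # client disconnects mid-generation
    streams["c"].cancel()

    # One batched abort per output-handler iteration, not one call per cancellation.
    if client_aborted:
        batch = list(client_aborted)
        client_aborted.clear()
        print("aborting in one call:", batch)  # -> aborting in one call: ['b', 'c']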
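
Note on the engine/__init__.py and core.py hunks: the IPC structs gain omit_defaults=True and gc=False, and the output socket loop reuses a single bytearray via Encoder.encode_into instead of allocating a fresh bytes object each step. Below is a rough msgspec round-trip sketch under those options; the struct and field names are placeholders rather than the actual vLLM classes.

    from typing import List, Optional, Union

    import msgspec
    from msgspec import msgpack


    class Output(msgspec.Struct, array_like=True, omit_defaults=True, gc=False):
        request_id: str
        new_token_ids: List[int]
        finished: bool
        finish_reason: Optional[str] = None
        stop_reason: Union[int, str, None] = None


    class Outputs(msgspec.Struct, array_like=True, omit_defaults=True, gc=False):
        outputs: List[Output]


    encoder = msgpack.Encoder()
    decoder = msgpack.Decoder(Outputs)
    buffer = bytearray()  # reused across messages, like the send buffer in core.py

    msg = Outputs(outputs=[Output("req-0", [1, 2, 3], finished=False)])
    encoder.encode_into(msg, buffer)      # serialize in place, no fresh bytes object
    assert decoder.decode(buffer) == msg  # omitted defaults are restored on decode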
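
Note on the detokenizer.py hunks: the field is renamed to stop_buffer_length because, when stop strings are excluded from streamed output, the last max(len(s) for s in stops) - 1 characters must be withheld in case they are the start of a stop string that straddles a chunk boundary. A toy illustration of that buffering idea, independent of the IncrementalDetokenizer API:

    from typing import Iterable, List


    def stream_with_stop(chunks: Iterable[str], stops: List[str]):
        """Yield streamed text, excluding the stop string, holding back a small tail."""
        held = max(len(s) for s in stops) - 1  # stop_buffer_length in the patch
        text = ""
        emitted = 0
        for chunk in chunks:
            text += chunk
            for s in stops:
                idx = text.find(s)
                if idx != -1:
                    yield text[emitted:idx]  # emit up to, but not including, the stop
                    return
            # Everything except the last `held` chars is safe to stream now.
            safe = max(len(text) - held, emitted)
            yield text[emitted:safe]
            emitted = safe


    print("".join(stream_with_stop(["Hello wo", "rld<|e", "nd|>ignored"], ["<|end|>"])))
    # -> Hello world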