Commit

format
robertgshaw2-redhat committed Nov 6, 2024
1 parent 7d3c114 commit cf5e63c
Showing 3 changed files with 72 additions and 45 deletions.
105 changes: 65 additions & 40 deletions vllm/v1/engine/async_llm.py
@@ -1,5 +1,7 @@
import asyncio
from typing import AsyncGenerator, Dict, List, Mapping, Optional, Type, Union
from functools import partial
from typing import (AsyncGenerator, Dict, List, Mapping, Optional, Type,
Union)

from vllm.config import ModelConfig, VllmConfig
from vllm.engine.arg_utils import AsyncEngineArgs
@@ -53,15 +55,15 @@ def __init__(
enable_lora=bool(vllm_config.lora_config))
self.tokenizer.ping()

# Map (request_id -> Stream)
# Request streams (map of request_id -> AsyncStream).
self.request_streams: Dict[str, AsyncStream] = {}

# Processor (converts Inputs --> EngineCoreRequests)
# Processor (converts Inputs --> EngineCoreRequests).
self.processor = Processor(vllm_config.model_config,
vllm_config.lora_config, self.tokenizer,
input_registry)

# Detokenizer (converts EngineCoreOutputs --> RequestOutput)
# Detokenizer (converts EngineCoreOutputs --> RequestOutput).
self.detokenizer = Detokenizer(vllm_config.model_config.tokenizer)

# EngineCore (starts the engine in background process).
@@ -73,9 +75,6 @@ def __init__(
asyncio_mode=True,
)

# TODO: add background loop shielding
# TODO: add AsyncEngineDeadError

self.is_output_handler_running = False

@classmethod
@@ -119,42 +118,71 @@ def shutdown(self):
def _get_executor_cls(cls, vllm_config: VllmConfig):
return GPUExecutor

def _add_request_to_streams(self, request_id: str) -> AsyncStream:
async def _abort_requests(
self,
request_ids: Union[str, List[str]],
*,
exception: Optional[Union[BaseException, Type[BaseException]]] = None,
verbose: bool,
) -> None:
"""
Abort requests. This function is called in two places:
* In output_handler_loop, if stop string is detected
* In iterate_with_cancellation (inside AsyncStream) when
a request disconnects. This function is as a callback.
"""

# Convert to a list if we got a single request_id str.
if isinstance(request_ids, str):
request_ids = [request_ids]

# Abort in EngineCore and Detokenizer.
await self.engine_core.abort_requests_async(request_ids)
self.detokenizer.abort_requests(request_ids)

# Remove from the request streams.
for request_id in request_ids:
stream = self.request_streams.pop(request_id, None)
if stream is not None:
stream.finish(exception=exception)

if verbose:
logger.info("Aborted request %s.", request_id)

def _add_request_to_streams(
self,
request_id: str,
verbose: bool = False,
) -> AsyncStream:
"""Add a request to the request request streams."""

if request_id in self.request_streams:
raise ValueError(f"Request id {request_id} already running.")

# TODO: handle abort.
# IDEA(Nick): we could batch up aborts rather than sending
# them individually, so that we send at most one batch of
# aborts per step (added to any that we're doing due to
# stop string matches for that step)
def _abort():
pass

stream = AsyncStream(request_id, _abort)
abort_callback = partial(self._abort_requests, verbose=verbose)
stream = AsyncStream(request_id, abort_callback)
self.request_streams[request_id] = stream

if verbose:
logger.info("Added request %s.", request_id)

return stream

def _send_to_streams(self, request_outputs: List[RequestOutput]):
"""Put the RequestOutputs into the corresponding AsyncStreams"""
def _process_request_outputs(self, request_outputs: List[RequestOutput]):
"""Put the outputs in streams and remove from tracker if finished."""

for request_output in request_outputs:
request_id = request_output.request_id
assert request_id in self.request_streams

# Each request in the API server pulls from these streams.
self.request_streams[request_id].put(request_output)

# If finished, remove from the tracker.
if request_output.finished:
self.request_streams[request_id].finish()
self.request_streams.pop(request_id)

async def abort_requests(self, request_ids: List[str]) -> None:
"""Remove request_ids from EngineCore and Detokenizer."""

if len(request_ids) > 0:
await self.engine_core.abort_requests_async(request_ids)
self.detokenizer.abort_requests(request_ids)

async def add_request(
self,
request_id: str,
@@ -166,12 +194,13 @@ async def add_request(
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
"""Add request_id to the EngineCore and return a Generator."""

if self.detokenizer.is_request_active(request_id):
raise KeyError(f"Request {request_id} already exists.")

# 1) Make AsyncStream and add to self.request_streams.
stream = self._add_request_to_streams(request_id)
# 1) Create a new AsyncStream and add it to self.request_streams.
stream = self._add_request_to_streams(request_id, verbose=True)

# 2) Convert input --> DetokenizerRequest / EngineCoreRequest.
detokenizer_req, engine_core_req = self.processor.process_inputs(
@@ -184,6 +213,7 @@ async def add_request(
# 4) Add the EngineCoreRequest to EngineCore (separate process).
await self.engine_core.add_request_async(engine_core_req)

# 5) Return the generator.
return stream.generator()

# TODO: we should support multiple prompts in one call, as you
@@ -222,26 +252,21 @@ async def generate(
yield output

async def _run_output_handler(self):

# TODO: add weakref from current AsyncLLMEngine
# TODO: shutdown remote worker execution loop (once TP enabled)

logger.debug("Starting output handler busy loop in background loop.")
"""Background loop: pulls from EngineCore and pushes to AsyncStreams."""

try:
while True:
# Get EngineCoreOutput from the EngineCore.
# 1) Pull EngineCoreOutput from the EngineCore.
outputs = await self.engine_core.get_output_async()

# Detokenize based on the output.
# 2) Detokenize based on the output.
request_outputs, reqs_to_abort = self.detokenizer.step(outputs)

# Put the RequestOutputs into the per-request AsyncStream.
# NOTE(rob): we could do the streaming in the detokenizer.
self._send_to_streams(request_outputs)
# 3) Put the RequestOutputs into the per-request AsyncStreams.
self._process_request_outputs(request_outputs)

# Abort any requests that finished due to stop strings.
await self.abort_requests(reqs_to_abort)
await self._abort_requests(reqs_to_abort, verbose=False)

except BaseException as e:
logger.error(e)
@@ -250,7 +275,7 @@ async def _run_output_handler(self):
# TODO: can we eliminate these?

async def abort(self, request_id: str) -> None:
# Note: Who Calls this?
# Note: Who calls this? I don't think this is actually used.
raise ValueError("Not Supported on V1 yet.")

def encode(
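The changes above replace the no-op _abort() stub with a real abort path shared by two callers: the output handler (for stop-string matches) and the per-request stream callback (for client disconnects), with functools.partial binding the verbose flag before the coroutine is handed to each AsyncStream. Below is a minimal, runnable sketch of that pattern; ToyEngine and its members are hypothetical stand-ins, not the actual vLLM classes.

```python
# Sketch (hypothetical names) of the abort path this commit introduces:
# one coroutine handles aborts from both stop strings and client
# disconnects, and functools.partial binds keyword-only flags before the
# coroutine is handed to each stream as a callback.
import asyncio
from functools import partial
from typing import Dict, List, Union


class ToyEngine:
    def __init__(self) -> None:
        # request_id -> asyncio.Queue, standing in for AsyncStream.
        self.request_streams: Dict[str, asyncio.Queue] = {}

    async def _abort_requests(
        self,
        request_ids: Union[str, List[str]],
        *,
        verbose: bool,
    ) -> None:
        # Accept a single id or a batch, mirroring the committed signature.
        if isinstance(request_ids, str):
            request_ids = [request_ids]
        for request_id in request_ids:
            self.request_streams.pop(request_id, None)
            if verbose:
                print(f"Aborted request {request_id}.")

    def add_request(self, request_id: str, verbose: bool = False):
        # Bind the keyword-only flag now; the stream only passes request_id.
        abort_callback = partial(self._abort_requests, verbose=verbose)
        self.request_streams[request_id] = asyncio.Queue()
        return abort_callback


async def main() -> None:
    engine = ToyEngine()
    abort = engine.add_request("req-0", verbose=True)
    await abort("req-0")  # e.g. the client disconnected
    assert "req-0" not in engine.request_streams


asyncio.run(main())
```

The Union[str, List[str]] signature exists because the two callers differ: the stream callback passes a single request_id, while the output handler passes a batch of stop-string aborts per step.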
9 changes: 5 additions & 4 deletions vllm/v1/engine/async_stream.py
@@ -1,5 +1,5 @@
import asyncio
from typing import Any, AsyncGenerator, Callable, Optional, Type, Union
from typing import Any, Awaitable, AsyncGenerator, Callable, Optional, Type, Union

from vllm.outputs import EmbeddingRequestOutput, RequestOutput

@@ -10,9 +10,10 @@ class AsyncStream:

STOP_ITERATION = Exception() # Sentinel

def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None:
def __init__(self, request_id: str,
abort_callback: Callable[[str], Awaitable[None]]) -> None:
self.request_id = request_id
self._cancel = cancel
self._cancel = abort_callback
self._queue: asyncio.Queue = asyncio.Queue()
self._finished = False

@@ -46,7 +47,7 @@ async def generator(
raise result
yield result
except GeneratorExit:
self._cancel(self.request_id)
await self._cancel(self.request_id)
raise asyncio.CancelledError from None

@staticmethod
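The async_stream.py change retypes the cancel hook from Callable[[str], None] to Callable[[str], Awaitable[None]] and awaits it inside the GeneratorExit handler, since aborting now involves an async round trip to the EngineCore process. Here is a small, self-contained sketch of how an awaitable callback behaves when a consumer closes the generator; ToyStream is a hypothetical stand-in, not the real AsyncStream.

```python
# Sketch (hypothetical names) of an async cancel callback being awaited
# during async-generator teardown.
import asyncio
from typing import Awaitable, Callable


class ToyStream:
    def __init__(self, request_id: str,
                 abort_callback: Callable[[str], Awaitable[None]]) -> None:
        self.request_id = request_id
        self._abort = abort_callback
        self._queue: asyncio.Queue = asyncio.Queue()

    def put(self, item: str) -> None:
        self._queue.put_nowait(item)

    async def generator(self):
        try:
            while True:
                item = await self._queue.get()
                if item == "<eos>":
                    return
                yield item
        except GeneratorExit:
            # The consumer went away; awaiting here is exactly what the
            # Awaitable callback type makes possible.
            await self._abort(self.request_id)
            raise


async def main() -> None:
    aborted = []

    async def abort(request_id: str) -> None:
        aborted.append(request_id)

    stream = ToyStream("req-0", abort)
    stream.put("hello")
    agen = stream.generator()
    print(await agen.__anext__())  # "hello"
    await agen.aclose()            # simulates the client disconnecting
    assert aborted == ["req-0"]


asyncio.run(main())
```

Awaiting inside the except GeneratorExit block is legal for async generators as long as the generator does not yield again before re-raising.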
3 changes: 2 additions & 1 deletion vllm/v1/engine/core_client.py
@@ -225,4 +225,5 @@ async def add_request_async(self, request: EngineCoreRequest) -> None:
await self._send_input(EngineCoreRequestType.ADD, request)

async def abort_requests_async(self, request_ids: List[str]) -> None:
await self._send_input(EngineCoreRequestType.ABORT, request_ids)
if len(request_ids) > 0:
await self._send_input(EngineCoreRequestType.ABORT, request_ids)
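Finally, core_client.py makes an empty abort list a no-op before anything crosses the process boundary; since the output handler now calls _abort_requests every step, most calls carry no ids. A sketch of the guard in isolation, with a hypothetical _send_input standing in for the real transport:

```python
# Sketch: skip the cross-process send entirely when there is nothing to
# abort. _send_input here is a hypothetical stand-in.
import asyncio
from typing import Any, List

sent: List[Any] = []


async def _send_input(request_type: str, payload: Any) -> None:
    # Stand-in for the real send to the EngineCore process.
    sent.append((request_type, payload))


async def abort_requests_async(request_ids: List[str]) -> None:
    # Most steps abort nothing, so skip the round trip entirely.
    if len(request_ids) > 0:
        await _send_input("ABORT", request_ids)


async def main() -> None:
    await abort_requests_async([])         # no message sent
    await abort_requests_async(["req-0"])  # one ABORT message
    assert sent == [("ABORT", ["req-0"])]


asyncio.run(main())
```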
