[V1] Refactor LLMEngine To Use Multiprocessing #9741

Closed · wants to merge 26 commits
Changes from all commits
28 changes: 21 additions & 7 deletions vllm/v1/core/scheduler.py
@@ -1,12 +1,13 @@
from collections import deque
from dataclasses import dataclass
from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union
from typing import Deque, Dict, Iterable, List, Optional, Set, Union

from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.logger import init_logger
from vllm.multimodal import MultiModalDataDict
from vllm.sampling_params import SamplingParams
from vllm.v1.core.kv_cache_manager import KVCacheManager
from vllm.v1.engine import EngineCoreOutput
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus

@@ -227,13 +228,13 @@ def update_from_output(
self,
scheduler_output: "SchedulerOutput",
model_runner_output: "ModelRunnerOutput",
) -> List[Tuple[Request, int]]:
) -> List[EngineCoreOutput]:
# NOTE(robertgshaw2): Should this method be in EngineCore instead?
# NOTE(woosuk): This method doesn't consider speculative decoding.
sampled_token_ids = model_runner_output.sampled_token_ids_cpu.tolist()
num_scheduled_tokens = scheduler_output.num_scheduled_tokens
new_running: List[Request] = []
# (request, num_sampled_tokens)
sampled: List[Tuple[Request, int]] = []
engine_core_outputs: List[EngineCoreOutput] = []
for request in self.running:
req_id = request.request_id
request.num_computed_tokens += num_scheduled_tokens[req_id]
@@ -247,17 +248,30 @@ def update_from_output(
# generates at most one token at each step.
token_id = sampled_token_ids[req_index]
request.output_token_ids.append(token_id)
sampled.append((request, 1))
num_new_tokens = 1

# TODO: Update the KV cache manager for prefix caching.

# Check if the request is finished.
# Check for stop and update request state.
# This must be called before we make the EngineCoreOutput.
stopped = self._check_stop(request)

# Add EngineCoreOutput for this Request.
output = EngineCoreOutput(
request_id=req_id,
new_token_ids=request.output_token_ids[-num_new_tokens:],
finished=request.is_finished(),
finish_reason=request.get_finished_reason(),
stop_reason=request.stop_reason)
engine_core_outputs.append(output)

# If the request stopped, do not add it back to the running queue.
if stopped:
continue

new_running.append(request)
self.running = new_running
return sampled
return engine_core_outputs

def _check_stop(self, request: Request) -> bool:
if (request.num_tokens >= self.max_model_len
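Note on the hunk above: `update_from_output` now builds and returns `EngineCoreOutput` objects (per-step token ids plus finish state) instead of `(Request, num_sampled_tokens)` tuples, so the result is plain data that can be shipped to another process. A minimal sketch of the returned shape, with made-up request ids and token ids:

```python
# Illustration only: the shape of what update_from_output() returns after this
# change. Request ids, token ids, and finish reasons are example values.
from vllm.v1.engine import EngineCoreOutput

engine_core_outputs = [
    # A request that is still running: one newly sampled token, not finished.
    EngineCoreOutput(request_id="req-0", new_token_ids=[1234], finished=False),
    # A request that hit a stop condition this step.
    EngineCoreOutput(request_id="req-1",
                     new_token_ids=[50256],
                     finished=True,
                     finish_reason="stop",
                     stop_reason=50256),
]
```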
51 changes: 51 additions & 0 deletions vllm/v1/engine/__init__.py
@@ -0,0 +1,51 @@
from dataclasses import dataclass
from typing import List, Optional, Union

import msgspec

from vllm.lora.request import LoRARequest
from vllm.sampling_params import RequestOutputKind, SamplingParams

LLM_ENGINE_CORE_READY_STR = "READY"


@dataclass
class DetokenizerRequest:

request_id: str
prompt: Optional[str]
prompt_token_ids: List[int]
skip_special_tokens: bool
spaces_between_special_tokens: bool
output_kind: RequestOutputKind


class EngineCoreRequest(msgspec.Struct):

# NOTE: prompt and prompt_token_ids should be DecoderOnlyInput,
# but this is not playing well with msgspec due to circular
# imports and weird typing we have going on in data.py

request_id: str
prompt: Optional[str]
prompt_token_ids: List[int]
sampling_params: SamplingParams
eos_token_id: Optional[int]
arrival_time: float
lora_request: Optional[LoRARequest]


@dataclass
class EngineCoreOutput:

request_id: str
new_token_ids: List[int]
finished: bool
finish_reason: Optional[str] = None
stop_reason: Union[int, str, None] = None


class EngineCoreOutputs(msgspec.Struct):

# [num_reqs]
outputs: List[EngineCoreOutput]
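Since this PR moves the engine core into a separate process, these structs appear to be the payloads that cross the process boundary. Below is a sketch (not part of the diff) of round-tripping an `EngineCoreOutputs` batch through msgspec's msgpack encoder; the actual transport (e.g. a socket) is an assumption and not shown, and the values are illustrative:

```python
# Sketch only: serialize and deserialize an EngineCoreOutputs batch with
# msgspec, as one might do when sending results from the engine-core process
# back to the front-end process. Assumes a msgspec version that can decode
# into dataclasses (EngineCoreOutput is a dataclass).
import msgspec

from vllm.v1.engine import EngineCoreOutput, EngineCoreOutputs

outputs = EngineCoreOutputs(outputs=[
    EngineCoreOutput(request_id="req-0", new_token_ids=[42], finished=False),
])

encoder = msgspec.msgpack.Encoder()
decoder = msgspec.msgpack.Decoder(EngineCoreOutputs)

payload = encoder.encode(outputs)    # bytes, ready to send over a socket
roundtrip = decoder.decode(payload)  # back to an EngineCoreOutputs instance
assert roundtrip.outputs[0].new_token_ids == [42]
```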
216 changes: 216 additions & 0 deletions vllm/v1/engine/detokenizer.py
@@ -0,0 +1,216 @@
from dataclasses import dataclass
from typing import Dict, List, Optional, Union

from vllm.logger import init_logger
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.sampling_params import RequestOutputKind
from vllm.transformers_utils.detokenizer_utils import (
AnyTokenizer, convert_prompt_ids_to_tokens, detokenize_incrementally)
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.v1.engine import DetokenizerRequest, EngineCoreOutput

logger = init_logger(__name__)


@dataclass
class DetokenizerRequestState:

# Generation data
output_text: str
tokens: List[str]
token_ids: List[int]

# Metadata for incremental detokenization
prefix_offset: int
read_offset: int

# Parameters for detokenization
skip_special_tokens: bool
spaces_between_special_tokens: bool
output_kind: RequestOutputKind

# Request output (Cached + updated incrementally)
request_output: RequestOutput

@classmethod
def from_new_request(cls, tokenizer: AnyTokenizer,
request: DetokenizerRequest):

tokens, prefix_offset, read_offset = convert_prompt_ids_to_tokens(
tokenizer=tokenizer,
prompt_ids=request.prompt_token_ids,
skip_special_tokens=request.skip_special_tokens,
)

request_output = cls._initialize_request_output(
request.request_id,
request.prompt,
request.prompt_token_ids,
)

return cls(
output_text="",
tokens=tokens,
# Detokenizer mutates this list, so need a unique copy.
token_ids=request.prompt_token_ids.copy(),
prefix_offset=prefix_offset,
read_offset=read_offset,
skip_special_tokens=request.skip_special_tokens,
spaces_between_special_tokens=request.
spaces_between_special_tokens,
output_kind=request.output_kind,
request_output=request_output)

@staticmethod
Review comment (Collaborator, PR author): Make this a @classmethod for RequestOutput

def _initialize_request_output(
request_id: str, prompt: Optional[str],
prompt_token_ids: List[int]) -> RequestOutput:
"""Initialize a new RequestOutput object."""

# TODO: Support `n` > 1.
completion_output = CompletionOutput(
index=0,
text="",
token_ids=[],
cumulative_logprob=None,
logprobs=None, # TODO
finish_reason=None,
stop_reason=None,
lora_request=None,
)

return RequestOutput(
request_id=request_id,
prompt=prompt,
prompt_token_ids=prompt_token_ids,
prompt_logprobs=None, # TODO
outputs=[completion_output],
finished=False,
metrics=None,
lora_request=None,
encoder_prompt=None,
encoder_prompt_token_ids=None,
)


class Detokenizer:

def __init__(self, tokenizer_name: str):
self.tokenizer = get_tokenizer(tokenizer_name)

# Request id -> DetokenizerRequestState
self.request_states: Dict[str, DetokenizerRequestState] = {}

def get_num_unfinished_requests(self) -> int:
return len(self.request_states)

def has_unfinished_requests(self) -> bool:
return len(self.request_states) > 0

def add_request(self, request: DetokenizerRequest) -> None:
"""Add new request to the Detokenizer."""

assert (request.request_id not in self.request_states)

request_state = DetokenizerRequestState.from_new_request(
self.tokenizer, request)
self.request_states[request.request_id] = request_state

def step(
self, engine_core_outputs: List[EngineCoreOutput]
) -> List[RequestOutput]:
"""Update the detokenizer state with the new tokens from EngineCore."""

request_outputs: List[RequestOutput] = []
for engine_core_output in engine_core_outputs:
request_id = engine_core_output.request_id
request_state = self.request_states[request_id]

# Detokenize and update state.
self._update_request_state(
tokenizer=self.tokenizer,
request_state=request_state,
new_token_ids=engine_core_output.new_token_ids,
finished=engine_core_output.finished,
finish_reason=engine_core_output.finish_reason,
stop_reason=engine_core_output.stop_reason,
)
request_outputs.append(request_state.request_output)

# Free completed requests.
if engine_core_output.finished:
self._free(request_id)

# Send RequestOutputs to EngineClient.
return request_outputs

def _free(self, request_id: str) -> None:
"""Remove the request from the RequestState tracker."""

# TODO(robertgshaw2): should this be a del?
assert request_id in self.request_states
self.request_states.pop(request_id)

@staticmethod
def _update_request_state(
tokenizer: AnyTokenizer,
request_state: DetokenizerRequestState,
new_token_ids: List[int],
finished: bool,
finish_reason: Optional[str],
stop_reason: Union[int, str, None],
) -> None:
"""
Update RequestState for the request_id by:
1) Detokenize the new token ids incrementally.
2) Update the RequestOutput with the new text.
"""

# 1) Detokenize the new token ids incrementally.
# TODO(woosuk): This method becomes very inefficient when the number of
# new_token_ids is more than 1. We need to optimize this.
decoded_text = ""
for new_token_id in new_token_ids:
request_state.token_ids.append(new_token_id)
(new_tokens, new_decoded_token_text, prefix_offset,
read_offset) = detokenize_incrementally(
tokenizer=tokenizer,
all_input_ids=request_state.token_ids,
prev_tokens=request_state.tokens,
prefix_offset=request_state.prefix_offset,
read_offset=request_state.read_offset,
skip_special_tokens=request_state.skip_special_tokens,
spaces_between_special_tokens=request_state.
spaces_between_special_tokens,
)

request_state.tokens.extend(new_tokens)
request_state.prefix_offset = prefix_offset
request_state.read_offset = read_offset
request_state.output_text += new_decoded_token_text

decoded_text += new_decoded_token_text

# 2) Update the RequestOutput object with the new text.
request_output = request_state.request_output
completion_output = request_output.outputs[0]
if request_state.output_kind == RequestOutputKind.CUMULATIVE:
completion_output.text += decoded_text
completion_output.token_ids = request_state.token_ids
elif request_state.output_kind == RequestOutputKind.DELTA:
completion_output.text = decoded_text
num_prev_tokens = len(completion_output.token_ids)
completion_output.token_ids = request_state.token_ids[
num_prev_tokens:]
elif request_state.output_kind == RequestOutputKind.FINAL_ONLY:
if finished:
completion_output.text = request_state.output_text
completion_output.token_ids = request_state.token_ids
else:
completion_output.text = ""
completion_output.token_ids = []

if finished:
completion_output.finish_reason = finish_reason
completion_output.stop_reason = stop_reason
request_output.finished = finished
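For orientation, here is a minimal hand-driven sketch of the `Detokenizer` above, using the `DetokenizerRequest` and `EngineCoreOutput` types from `vllm/v1/engine/__init__.py`. The tokenizer name, request id, and token ids are placeholder values; in the real engine these calls would be made by the front-end `LLMEngine`, not by hand:

```python
# Sketch only: drive the Detokenizer directly with placeholder values.
from vllm.sampling_params import RequestOutputKind
from vllm.v1.engine import DetokenizerRequest, EngineCoreOutput
from vllm.v1.engine.detokenizer import Detokenizer

detokenizer = Detokenizer("gpt2")  # placeholder tokenizer name

# Register a request (prompt and prompt token ids are placeholders).
detokenizer.add_request(
    DetokenizerRequest(
        request_id="req-0",
        prompt="Hello",
        prompt_token_ids=[15496],
        skip_special_tokens=True,
        spaces_between_special_tokens=True,
        output_kind=RequestOutputKind.CUMULATIVE,
    ))

# Feed one step of engine-core output; step() returns RequestOutputs and frees
# the request state because finished=True.
request_outputs = detokenizer.step([
    EngineCoreOutput(request_id="req-0",
                     new_token_ids=[11, 995],
                     finished=True,
                     finish_reason="stop"),
])
print(request_outputs[0].outputs[0].text)
```

With `output_kind=CUMULATIVE` the cached `CompletionOutput.text` accumulates across steps, whereas `DELTA` returns only the newly decoded text each step and `FINAL_ONLY` returns text only once the request finishes.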