[V1] Prototype Fully Async Detokenizer #9725
Changes from all commits
@@ -14,18 +14,19 @@
 from vllm.inputs.preprocess import InputPreprocessor
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
-from vllm.outputs import CompletionOutput, RequestOutput
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
-from vllm.sampling_params import RequestOutputKind, SamplingParams
+from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.config import try_get_generation_config
 from vllm.transformers_utils.tokenizer_group import (
     BaseTokenizerGroup, init_tokenizer_from_configs)
 from vllm.usage.usage_lib import UsageContext
 from vllm.v1.core.scheduler import Scheduler
 from vllm.v1.executor.gpu_executor import GPUExecutor
 from vllm.v1.request import Request, RequestStatus
-from vllm.v1.tokenizer.detokenizer import Detokenizer, DetokenizerInputs
+from vllm.v1.tokenizer.detokenizer import (Detokenizer, DetokenizerInputData,
+                                           DetokenizerInputs,
+                                           DetokenizerNewRequest)
 from vllm.version import __version__ as VLLM_VERSION

 logger = init_logger(__name__)
@@ -48,6 +49,7 @@ def __init__(
         prompt_adapter_config: Optional[PromptAdapterConfig],
         executor_class: Type[GPUExecutor],
         log_stats: bool,
+        output_socket_path: str,
         usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
         stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
         input_registry: InputRegistry = INPUT_REGISTRY,
@@ -132,7 +134,8 @@ def __init__(
         # Ping the tokenizer to ensure liveness if it runs in a
         # different process.
         self.tokenizer.ping()
-        self.detokenizer = Detokenizer(self.model_config.tokenizer)
+        self.detokenizer = Detokenizer(self.model_config.tokenizer,
+                                       output_socket_path)

         self.generation_config_fields = _load_generation_config_dict(
             model_config)
@@ -142,18 +145,6 @@ def __init__(
         self.input_processor = input_registry.create_input_processor(
             model_config)

-        # Request id -> Request
-        self.requests: Dict[str, Request] = {}
-        # NOTE(woosuk): Now that the detokenizer works asynchronously, we need
-        # to keep track of how many steps each request has been lagged behind
-        # in terms of detokenization.
-        # Request id -> how many detokenizer steps the request should wait for.
-        self.num_lagged_steps: Dict[str, int] = {}
-        # OPTIMIZATION: Cache the request output and update it incrementally.
-        # This is used to avoid creating a new RequestOutput object every step.
-        # Request id -> RequestOutput
-        self.request_outputs: Dict[str, RequestOutput] = {}
Reviewer comment: The key benefit of this PR in terms of code complexity is that we simplify the […].
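The bookkeeping removed above (`self.requests`, `self.num_lagged_steps`, `self.request_outputs`) is what the comment refers to: that per-request state can now live inside the Detokenizer process instead of the engine. A minimal sketch of the idea, with illustrative class and field names that are not the PR's actual detokenizer code:

```python
# Illustrative sketch only: per-request state owned by the Detokenizer
# process, so the engine no longer needs self.requests,
# self.num_lagged_steps, or self.request_outputs. Names are hypothetical.
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class _RequestState:
    """Detokenization state for one request."""
    prompt: str
    prompt_token_ids: List[int]
    output_token_ids: List[int] = field(default_factory=list)
    output_text: str = ""


class _RequestStateTable:
    """Bookkeeping that previously lived on the LLMEngine."""

    def __init__(self) -> None:
        self._states: Dict[str, _RequestState] = {}

    def add_request(self, request_id: str, prompt: str,
                    prompt_token_ids: List[int]) -> None:
        self._states[request_id] = _RequestState(prompt, prompt_token_ids)

    def update(self, request_id: str, new_token_ids: List[int],
               new_text: str, finished: bool) -> None:
        state = self._states[request_id]
        state.output_token_ids.extend(new_token_ids)
        state.output_text += new_text
        if finished:
            # Freeing happens here, so the engine never has to track
            # "lagged" detokenization steps per request.
            del self._states[request_id]
```

Because text accumulation and freeing happen next to the tokenizer, the engine loop only forwards token ids and never blocks on detokenization.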

         self.model_executor = executor_class(
             model_config=model_config,
             cache_config=cache_config,
@@ -254,9 +245,8 @@ def _add_processed_request(
         # TODO(woosuk): Support encoder-decoder models.
         req = Request(request_id, processed_inputs, params, eos_token_id,
                       arrival_time)
-        self.requests[request_id] = req
-        self.num_lagged_steps[request_id] = 0
         self.scheduler.add_request(req)
+        self._add_to_detokenizer(req)

     def stop_remote_worker_execution_loop(self) -> None:
         raise NotImplementedError("TP not implemented yet.")
@@ -300,155 +290,62 @@ def add_request(
     def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
         self.scheduler.finish_requests(request_id,
                                        RequestStatus.FINISHED_ABORTED)
-        self._free_request(request_id)
+        # TODO(robertgshaw2): send msg to Detokenizer to free.
+        # (maybe this already happens, need to check)

     def get_num_unfinished_requests(self) -> int:
         """Gets the number of unfinished requests."""
-        return len(self.requests)
+        return self.scheduler.get_num_unfinished_requests()

     def has_unfinished_requests(self) -> bool:
         """Returns True if there are unfinished requests."""
-        return len(self.requests) > 0

-    def step(self) -> List[RequestOutput]:
-        # NOTE(woosuk): This method may return an empty list when the
-        # detokenizer is still processing the outputs. This should not be
-        # considered as the end of the generation process.
-        # FIXME(woosuk): Currently, the step method is inefficient because it
-        # creates RequestOutput objects for all running requests, while they
-        # may not be needed unless the output is streamed to the client.
+        return self.scheduler.has_unfinished_requests()

+    def step(self) -> None:
Reviewer comment: This breaks the API, as […]. However, I will note that this is already how we use the […].
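To make the new control flow concrete: since `step()` now returns `None`, request outputs must be consumed from the Detokenizer side rather than from the engine loop. A rough sketch, assuming the Detokenizer pushes pickled `RequestOutput` objects to `output_socket_path` over a ZeroMQ PUSH/PULL pair; the actual socket type and wire format used by this PR may differ:

```python
# Hedged sketch: driving the engine once step() returns None. Assumes (not
# confirmed by this diff) that the Detokenizer process pushes pickled
# RequestOutput objects to `output_socket_path` over a ZeroMQ PUSH socket.
import pickle

import zmq


def run_engine_loop(engine) -> None:
    # The busy loop only schedules and executes; it never builds or returns
    # RequestOutput objects.
    while engine.has_unfinished_requests():
        engine.step()


def consume_outputs(output_socket_path: str) -> None:
    # Runs on the client side (e.g. a separate thread or process) and
    # streams whatever the Detokenizer publishes.
    ctx = zmq.Context()
    sock = ctx.socket(zmq.PULL)
    sock.connect(output_socket_path)  # e.g. "ipc:///tmp/vllm_outputs"
    try:
        while True:
            request_output = pickle.loads(sock.recv())
            print(request_output.request_id,
                  request_output.outputs[0].text)
            if request_output.finished:
                break
    finally:
        sock.close()
        ctx.term()
```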
"""Schedule, execute, and send output to Detokenizer.""" | ||
if self.scheduler.has_unfinished_requests(): | ||
scheduler_output = self.scheduler.schedule() | ||
output = self.model_executor.execute_model(scheduler_output) | ||
sampled = self.scheduler.update_from_output( | ||
scheduler_output, output) | ||
self.send_to_detokenizer(sampled) | ||
req_outputs = self.recv_from_detokenizer() | ||
return req_outputs | ||
|
||
def send_to_detokenizer(self, sampled: List[Tuple[Request, int]]) -> None: | ||
inputs = DetokenizerInputs( | ||
req_ids=[], | ||
prompt_token_ids=[], | ||
new_token_ids=[], | ||
skip_special_tokens=[], | ||
spaces_between_special_tokens=[], | ||
free_req_ids=[], # TODO(woosuk): Implement freeing. | ||
def _add_to_detokenizer(self, request: Request) -> None: | ||
"""Create DetokenizerNewRequest and send to Detokenizer.""" | ||
|
||
new_request = DetokenizerNewRequest( | ||
request_id=request.request_id, | ||
prompt=request.prompt, | ||
prompt_token_ids=request.prompt_token_ids, | ||
skip_special_tokens=request.sampling_params.skip_special_tokens, | ||
spaces_between_special_tokens=request.sampling_params. | ||
spaces_between_special_tokens, | ||
output_kind=request.sampling_params.output_kind, | ||
) | ||
for req, num_tokens in sampled: | ||
inputs.req_ids.append(req.request_id) | ||
if len(req.output_token_ids) == num_tokens: | ||
# The request is first detokenized. | ||
inputs.prompt_token_ids.append(req.prompt_token_ids) | ||
else: | ||
# The prompt token ids are already cached in the detokenizer. | ||
inputs.prompt_token_ids.append([]) | ||
inputs.new_token_ids.append(req.output_token_ids[-num_tokens:]) | ||
inputs.skip_special_tokens.append( | ||
req.sampling_params.skip_special_tokens) | ||
inputs.spaces_between_special_tokens.append( | ||
req.sampling_params.spaces_between_special_tokens) | ||
|
||
# Update the number of lagged steps. | ||
self.num_lagged_steps[req.request_id] += 1 | ||
self.detokenizer.send(inputs) | ||
|
||
def recv_from_detokenizer(self) -> List[RequestOutput]: | ||
detokenizer_output = self.detokenizer.recv() | ||
if detokenizer_output is None: | ||
return [] | ||
|
||
req_outputs: List[RequestOutput] = [] | ||
num_reqs = len(detokenizer_output.req_ids) | ||
for i in range(num_reqs): | ||
req_id = detokenizer_output.req_ids[i] | ||
if req_id not in self.requests: | ||
# The request has been aborted while the detokenizer was | ||
# processing the outputs. | ||
continue | ||
|
||
req = self.requests[req_id] | ||
req.output_text += detokenizer_output.detokenized_texts[i] | ||
|
||
self.num_lagged_steps[req_id] -= 1 | ||
finished = (self.num_lagged_steps[req_id] == 0 | ||
and req.is_finished()) | ||
req_output = self._make_request_output( | ||
req, detokenizer_output.num_output_token_ids[i], | ||
detokenizer_output.detokenized_texts[i], finished) | ||
req_outputs.append(req_output) | ||
|
||
if finished: | ||
self._free_request(req_id) | ||
return req_outputs | ||
|
||
self.detokenizer.add_request(new_request) | ||
|
||
def send_to_detokenizer(self, sampled: List[Tuple[Request, int]]) -> None: | ||
"""Send new tokens to Detokenizer.""" | ||
|
||
# TODO(robertgshaw2): We could avoid this conversion loop by either/or: | ||
# - scheduler.update_from_output() creates DetokenizerInputData | ||
# - serializing and sending the Requests directly to the Detokenizer | ||
# The negative of this is that the Detokenizer is then more coupled. | ||
input_data = [ | ||
DetokenizerInputData( | ||
request_id=req.request_id, | ||
new_token_ids=req.output_token_ids[-num_tokens:], | ||
finished=req.is_finished(), | ||
finish_reason=req.get_finished_reason(), | ||
stop_reason=req.stop_reason) for req, num_tokens in sampled | ||
] | ||
|
||
self.detokenizer.send(DetokenizerInputs(data=input_data)) | ||
|
||
def terminate_detokenizer(self) -> None: | ||
self.detokenizer.terminate() | ||
|
||
-    def _make_request_output(
-        self,
-        request: Request,
-        num_output_tokens: int,
-        new_output_text: str,
-        finished: bool,
-    ) -> RequestOutput:
-        req_output = self.request_outputs.get(request.request_id)
-        if req_output is None:
-            # TODO: Support `n` > 1.
-            completion_output = CompletionOutput(
-                index=0,
-                text="",
-                token_ids=[],
-                cumulative_logprob=None,
-                logprobs=None,  # TODO
-                finish_reason=None,
-                stop_reason=None,
-                lora_request=None,
-            )
-            req_output = RequestOutput(
-                request_id=request.request_id,
-                prompt=request.prompt,
-                prompt_token_ids=request.prompt_token_ids,
-                prompt_logprobs=None,  # TODO
-                outputs=[completion_output],
-                finished=False,
-                metrics=None,
-                lora_request=None,
-                encoder_prompt=None,
-                encoder_prompt_token_ids=None,
-            )
-            self.request_outputs[request.request_id] = req_output

-        completion_output = req_output.outputs[0]
-        if request.sampling_params.output_kind == RequestOutputKind.CUMULATIVE:
-            completion_output.text += new_output_text
-            completion_output.token_ids = (
-                request.output_token_ids[:num_output_tokens])
-        elif request.sampling_params.output_kind == RequestOutputKind.DELTA:
-            completion_output.text = new_output_text
-            num_prev_tokens = len(completion_output.token_ids)
-            completion_output.token_ids = request.output_token_ids[
-                num_prev_tokens:num_output_tokens]
-        elif (request.sampling_params.output_kind ==
-              RequestOutputKind.FINAL_ONLY):
-            if finished:
-                completion_output.text = request.output_text
-                completion_output.token_ids = request.output_token_ids
-            else:
-                completion_output.text = ""
-                completion_output.token_ids = []

-        if finished:
-            completion_output.finish_reason = request.get_finished_reason()
-            completion_output.stop_reason = request.stop_reason
-        req_output.finished = finished
-        return req_output

-    def _free_request(self, request_id: str) -> None:
-        self.requests.pop(request_id, None)
-        self.num_lagged_steps.pop(request_id, None)
-        self.request_outputs.pop(request_id, None)

Reviewer comment (on the removal of _make_request_output): We reduce overhead because: […]
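The overhead argument follows from what crosses the process boundary each step: the engine now serializes only small token-id deltas, and `RequestOutput`/`CompletionOutput` construction (the code removed above) moves into the Detokenizer process, off the scheduling critical path. The message types themselves are not shown in this diff; inferred from the constructor calls, they look roughly like the dataclasses below, though the real definitions live in `vllm/v1/tokenizer/detokenizer.py` and may be implemented differently (e.g. as msgspec structs):

```python
# Hedged reconstruction of the message shapes used in this diff, inferred
# from the constructor calls in llm_engine.py; not the PR's actual classes.
from dataclasses import dataclass
from typing import List, Optional, Union

from vllm.sampling_params import RequestOutputKind


@dataclass
class DetokenizerNewRequest:
    """Sent once per request, when it is added to the engine."""
    request_id: str
    prompt: str
    prompt_token_ids: List[int]
    skip_special_tokens: bool
    spaces_between_special_tokens: bool
    output_kind: RequestOutputKind


@dataclass
class DetokenizerInputData:
    """Per-request delta sent every step: token ids only, no strings."""
    request_id: str
    new_token_ids: List[int]
    finished: bool
    finish_reason: Optional[str] = None
    stop_reason: Union[int, str, None] = None


@dataclass
class DetokenizerInputs:
    """Batch of per-request deltas produced by one engine step."""
    data: List[DetokenizerInputData]
```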
     def check_health(self) -> None:
         if self.tokenizer:
             self.tokenizer.check_health()
Reviewer comment: I think this should be okay, but I want to double-check. If this is a problem, we can convert this to a separate "error" socket and poll on the error and output sockets from the MQLLMEngineClient.
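For reference, the two-socket polling suggested above could look roughly like the client-side sketch below; the socket names, types, and pickle wire format are assumptions for illustration, not the current `MQLLMEngineClient` behavior:

```python
# Illustrative sketch of polling an output socket and a separate error
# socket from the client. Socket paths, types, and serialization are
# assumptions, not this PR's implementation.
import pickle

import zmq


def poll_output_and_error(output_path: str, error_path: str) -> None:
    ctx = zmq.Context()
    out_sock = ctx.socket(zmq.PULL)
    out_sock.connect(output_path)
    err_sock = ctx.socket(zmq.PULL)
    err_sock.connect(error_path)

    poller = zmq.Poller()
    poller.register(out_sock, zmq.POLLIN)
    poller.register(err_sock, zmq.POLLIN)

    try:
        while True:
            ready = dict(poller.poll(timeout=1000))  # milliseconds
            if err_sock in ready:
                # Surface engine-side failures instead of hanging forever.
                raise RuntimeError(pickle.loads(err_sock.recv()))
            if out_sock in ready:
                request_output = pickle.loads(out_sock.recv())
                if request_output.finished:
                    break
    finally:
        out_sock.close()
        err_sock.close()
        ctx.term()
```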