Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[V1] Prototype Fully Async Detokenizer #9725

Closed
wants to merge 12 commits into from
2 changes: 1 addition & 1 deletion vllm/engine/multiprocessing/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def __init__(self, ipc_path: str, engine_config: EngineConfig):

# Receive streams of RequestOutput from the MQLLMEngine.
self.output_socket: Socket = self.context.socket(zmq.constants.PULL)
self.output_socket.connect(f"{ipc_path}{IPC_OUTPUT_EXT}")
self.output_socket.bind(f"{ipc_path}{IPC_OUTPUT_EXT}")

# IPC path for acking heartbeats.
self.heartbeat_socket: Socket = self.context.socket(zmq.constants.PULL)
Expand Down
15 changes: 14 additions & 1 deletion vllm/engine/multiprocessing/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,11 @@ def __init__(self,
# the python object to be reused again.
kwargs['use_cached_outputs'] = True

# For V1 Engine, pass down the output socket path, since
# the LLMEngine needs to pass it to the Detokenizer.
if VLLM_USE_V1:
kwargs['output_socket_path'] = f"{ipc_path}{IPC_OUTPUT_EXT}"

self.engine = LLMEngine(*args, **kwargs)
self.log_requests = log_requests

Expand All @@ -95,8 +100,10 @@ def __init__(self,
self.input_socket.bind(f"{ipc_path}{IPC_INPUT_EXT}")

# Send output stream back to client.
# TODO(robertgshaw2): this currently uses the same path as
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this should be okay, but I want to double check. If this is a problem, we can convert this to a separate "error" socket and poll on the error and output socket from the MQLLMEngineClient

# the Detokenizer output socket. This may or may not be okay.
self.output_socket = self.ctx.socket(zmq.constants.PUSH)
self.output_socket.bind(f"{ipc_path}{IPC_OUTPUT_EXT}")
self.output_socket.connect(f"{ipc_path}{IPC_OUTPUT_EXT}")

# Send heartbeats back to client.
self.heartbeat_socket = self.ctx.socket(zmq.constants.PUSH)
Expand Down Expand Up @@ -227,6 +234,12 @@ def run_engine_loop(self):

# Send request outputs (if async, done in engine_step callback).
if not self.use_async_sockets:

# In V1 Engine, Detokenizer sends the outputs to EngineClient.
# note: if request_outputs=None, self._send-outputs is a no-op.
if VLLM_USE_V1:
assert request_outputs is None

self._send_outputs(request_outputs)

def engine_step(self) -> List[RequestOutput]:
Expand Down
195 changes: 46 additions & 149 deletions vllm/v1/engine/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,19 @@
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.config import try_get_generation_config
from vllm.transformers_utils.tokenizer_group import (
BaseTokenizerGroup, init_tokenizer_from_configs)
from vllm.usage.usage_lib import UsageContext
from vllm.v1.core.scheduler import Scheduler
from vllm.v1.executor.gpu_executor import GPUExecutor
from vllm.v1.request import Request, RequestStatus
from vllm.v1.tokenizer.detokenizer import Detokenizer, DetokenizerInputs
from vllm.v1.tokenizer.detokenizer import (Detokenizer, DetokenizerInputData,
DetokenizerInputs,
DetokenizerNewRequest)
from vllm.version import __version__ as VLLM_VERSION

logger = init_logger(__name__)
Expand All @@ -48,6 +49,7 @@ def __init__(
prompt_adapter_config: Optional[PromptAdapterConfig],
executor_class: Type[GPUExecutor],
log_stats: bool,
output_socket_path: str,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
input_registry: InputRegistry = INPUT_REGISTRY,
Expand Down Expand Up @@ -132,7 +134,8 @@ def __init__(
# Ping the tokenizer to ensure liveness if it runs in a
# different process.
self.tokenizer.ping()
self.detokenizer = Detokenizer(self.model_config.tokenizer)
self.detokenizer = Detokenizer(self.model_config.tokenizer,
output_socket_path)

self.generation_config_fields = _load_generation_config_dict(
model_config)
Expand All @@ -142,18 +145,6 @@ def __init__(
self.input_processor = input_registry.create_input_processor(
model_config)

# Request id -> Request
self.requests: Dict[str, Request] = {}
# NOTE(woosuk): Now that the detokenizer works asynchronously, we need
# to keep track of how many steps each request has been lagged behind
# in terms of detokenization.
# Request id -> how many detokenizer steps the request should wait for.
self.num_lagged_steps: Dict[str, int] = {}
# OPTIMIZATION: Cache the request output and update it incrementally.
# This is used to avoid creating a new RequestOutput object every step.
# Request id -> RequestOutput
self.request_outputs: Dict[str, RequestOutput] = {}
Copy link
Collaborator Author

@robertgshaw2-redhat robertgshaw2-redhat Oct 27, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The key benefit of this PR in terms of code complexity is that we simplify the LLMEngine by no longer needing to keep track of these data.


self.model_executor = executor_class(
model_config=model_config,
cache_config=cache_config,
Expand Down Expand Up @@ -254,9 +245,8 @@ def _add_processed_request(
# TODO(woosuk): Support encoder-decoder models.
req = Request(request_id, processed_inputs, params, eos_token_id,
arrival_time)
self.requests[request_id] = req
self.num_lagged_steps[request_id] = 0
self.scheduler.add_request(req)
self._add_to_detokenizer(req)

def stop_remote_worker_execution_loop(self) -> None:
raise NotImplementedError("TP not implemented yet.")
Expand Down Expand Up @@ -300,155 +290,62 @@ def add_request(
def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
self.scheduler.finish_requests(request_id,
RequestStatus.FINISHED_ABORTED)
self._free_request(request_id)
# TODO(robertgshaw2): send msg to Detokenizer to free.
# (maybe this already happens, need to check)

def get_num_unfinished_requests(self) -> int:
"""Gets the number of unfinished requests."""
return len(self.requests)
return self.scheduler.get_num_unfinished_requests()

def has_unfinished_requests(self) -> bool:
"""Returns True if there are unfinished requests."""
return len(self.requests) > 0

def step(self) -> List[RequestOutput]:
# NOTE(woosuk): This method may return an empty list when the
# detokenizer is still processing the outputs. This should not be
# considered as the end of the generation process.
# FIXME(woosuk): Currently, the step method is inefficient because it
# creates RequestOutput objects for all running requests, while they
# may not be needed unless the output is streamed to the client.
return self.scheduler.has_unfinished_requests()

def step(self) -> None:
Copy link
Collaborator Author

@robertgshaw2-redhat robertgshaw2-redhat Oct 27, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This breaks the API, as step no longer returns RequestOutput.

However, I will note that this is already how we use the LLMEngine in the MQLLMEngine with async_process_outputs (the return value of step is [ ])

"""Schedule, execute, and send output to Detokenizer."""
if self.scheduler.has_unfinished_requests():
scheduler_output = self.scheduler.schedule()
output = self.model_executor.execute_model(scheduler_output)
sampled = self.scheduler.update_from_output(
scheduler_output, output)
self.send_to_detokenizer(sampled)
req_outputs = self.recv_from_detokenizer()
return req_outputs

def send_to_detokenizer(self, sampled: List[Tuple[Request, int]]) -> None:
inputs = DetokenizerInputs(
req_ids=[],
prompt_token_ids=[],
new_token_ids=[],
skip_special_tokens=[],
spaces_between_special_tokens=[],
free_req_ids=[], # TODO(woosuk): Implement freeing.
def _add_to_detokenizer(self, request: Request) -> None:
"""Create DetokenizerNewRequest and send to Detokenizer."""

new_request = DetokenizerNewRequest(
request_id=request.request_id,
prompt=request.prompt,
prompt_token_ids=request.prompt_token_ids,
skip_special_tokens=request.sampling_params.skip_special_tokens,
spaces_between_special_tokens=request.sampling_params.
spaces_between_special_tokens,
output_kind=request.sampling_params.output_kind,
)
for req, num_tokens in sampled:
inputs.req_ids.append(req.request_id)
if len(req.output_token_ids) == num_tokens:
# The request is first detokenized.
inputs.prompt_token_ids.append(req.prompt_token_ids)
else:
# The prompt token ids are already cached in the detokenizer.
inputs.prompt_token_ids.append([])
inputs.new_token_ids.append(req.output_token_ids[-num_tokens:])
inputs.skip_special_tokens.append(
req.sampling_params.skip_special_tokens)
inputs.spaces_between_special_tokens.append(
req.sampling_params.spaces_between_special_tokens)

# Update the number of lagged steps.
self.num_lagged_steps[req.request_id] += 1
self.detokenizer.send(inputs)

def recv_from_detokenizer(self) -> List[RequestOutput]:
detokenizer_output = self.detokenizer.recv()
if detokenizer_output is None:
return []

req_outputs: List[RequestOutput] = []
num_reqs = len(detokenizer_output.req_ids)
for i in range(num_reqs):
req_id = detokenizer_output.req_ids[i]
if req_id not in self.requests:
# The request has been aborted while the detokenizer was
# processing the outputs.
continue

req = self.requests[req_id]
req.output_text += detokenizer_output.detokenized_texts[i]

self.num_lagged_steps[req_id] -= 1
finished = (self.num_lagged_steps[req_id] == 0
and req.is_finished())
req_output = self._make_request_output(
req, detokenizer_output.num_output_token_ids[i],
detokenizer_output.detokenized_texts[i], finished)
req_outputs.append(req_output)

if finished:
self._free_request(req_id)
return req_outputs

self.detokenizer.add_request(new_request)

def send_to_detokenizer(self, sampled: List[Tuple[Request, int]]) -> None:
"""Send new tokens to Detokenizer."""

# TODO(robertgshaw2): We could avoid this conversion loop by either/or:
# - scheduler.update_from_output() creates DetokenizerInputData
# - serializing and sending the Requests directly to the Detokenizer
# The negative of this is that the Detokenizer is then more coupled.
input_data = [
DetokenizerInputData(
request_id=req.request_id,
new_token_ids=req.output_token_ids[-num_tokens:],
finished=req.is_finished(),
finish_reason=req.get_finished_reason(),
stop_reason=req.stop_reason) for req, num_tokens in sampled
]

self.detokenizer.send(DetokenizerInputs(data=input_data))

def terminate_detokenizer(self) -> None:
self.detokenizer.terminate()

def _make_request_output(
self,
Copy link
Collaborator Author

@robertgshaw2-redhat robertgshaw2-redhat Oct 27, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We reduce overhead because:

  • We now create the RequestOutputs in the Detokenizer
  • We now send the RequestOutputs to the MQLLMEngineClient in the Detokenizer, rather than from the MQLLMEngine (which is in the same process as the LLMEngine)

request: Request,
num_output_tokens: int,
new_output_text: str,
finished: bool,
) -> RequestOutput:
req_output = self.request_outputs.get(request.request_id)
if req_output is None:
# TODO: Support `n` > 1.
completion_output = CompletionOutput(
index=0,
text="",
token_ids=[],
cumulative_logprob=None,
logprobs=None, # TODO
finish_reason=None,
stop_reason=None,
lora_request=None,
)
req_output = RequestOutput(
request_id=request.request_id,
prompt=request.prompt,
prompt_token_ids=request.prompt_token_ids,
prompt_logprobs=None, # TODO
outputs=[completion_output],
finished=False,
metrics=None,
lora_request=None,
encoder_prompt=None,
encoder_prompt_token_ids=None,
)
self.request_outputs[request.request_id] = req_output

completion_output = req_output.outputs[0]
if request.sampling_params.output_kind == RequestOutputKind.CUMULATIVE:
completion_output.text += new_output_text
completion_output.token_ids = (
request.output_token_ids[:num_output_tokens])
elif request.sampling_params.output_kind == RequestOutputKind.DELTA:
completion_output.text = new_output_text
num_prev_tokens = len(completion_output.token_ids)
completion_output.token_ids = request.output_token_ids[
num_prev_tokens:num_output_tokens]
elif (request.sampling_params.output_kind ==
RequestOutputKind.FINAL_ONLY):
if finished:
completion_output.text = request.output_text
completion_output.token_ids = request.output_token_ids
else:
completion_output.text = ""
completion_output.token_ids = []

if finished:
completion_output.finish_reason = request.get_finished_reason()
completion_output.stop_reason = request.stop_reason
req_output.finished = finished
return req_output

def _free_request(self, request_id: str) -> None:
self.requests.pop(request_id, None)
self.num_lagged_steps.pop(request_id, None)
self.request_outputs.pop(request_id, None)

def check_health(self) -> None:
if self.tokenizer:
self.tokenizer.check_health()
Expand Down
1 change: 0 additions & 1 deletion vllm/v1/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ def __init__(
self.prompt_token_ids = inputs["prompt_token_ids"]
self.num_prompt_tokens = len(self.prompt_token_ids)
self.output_token_ids: List[int] = []
self.output_text = ""
self.num_computed_tokens = 0

@property
Expand Down
Loading
Loading