[V1][Metrics] Add several request timing histograms (vllm-project#12644)
Signed-off-by: Mark McLoughlin <[email protected]>
markmc authored Feb 11, 2025
1 parent 110f59a commit 75e6e14
Showing 16 changed files with 334 additions and 84 deletions.
31 changes: 31 additions & 0 deletions tests/entrypoints/openai/test_metrics.py
@@ -85,6 +85,10 @@ async def client(server):
"vllm:time_per_output_token_seconds":
[("_count", _NUM_REQUESTS * (_NUM_GENERATION_TOKENS_PER_REQUEST - 1))],
"vllm:e2e_request_latency_seconds": [("_count", _NUM_REQUESTS)],
"vllm:request_queue_time_seconds": [("_count", _NUM_REQUESTS)],
"vllm:request_inference_time_seconds": [("_count", _NUM_REQUESTS)],
"vllm:request_prefill_time_seconds": [("_count", _NUM_REQUESTS)],
"vllm:request_decode_time_seconds": [("_count", _NUM_REQUESTS)],
"vllm:request_prompt_tokens":
[("_sum", _NUM_REQUESTS * _NUM_PROMPT_TOKENS_PER_REQUEST),
("_count", _NUM_REQUESTS)],
@@ -169,6 +173,18 @@ async def test_metrics_counts(server: RemoteOpenAIServer,
"vllm:e2e_request_latency_seconds_sum",
"vllm:e2e_request_latency_seconds_bucket",
"vllm:e2e_request_latency_seconds_count",
"vllm:request_queue_time_seconds_sum",
"vllm:request_queue_time_seconds_bucket",
"vllm:request_queue_time_seconds_count",
"vllm:request_inference_time_seconds_sum",
"vllm:request_inference_time_seconds_bucket",
"vllm:request_inference_time_seconds_count",
"vllm:request_prefill_time_seconds_sum",
"vllm:request_prefill_time_seconds_bucket",
"vllm:request_prefill_time_seconds_count",
"vllm:request_decode_time_seconds_sum",
"vllm:request_decode_time_seconds_bucket",
"vllm:request_decode_time_seconds_count",
"vllm:request_prompt_tokens_sum",
"vllm:request_prompt_tokens_bucket",
"vllm:request_prompt_tokens_count",
@@ -220,6 +236,21 @@ async def test_metrics_counts(server: RemoteOpenAIServer,
"vllm:time_per_output_token_seconds_sum",
"vllm:time_per_output_token_seconds_bucket",
"vllm:time_per_output_token_seconds_count",
"vllm:e2e_request_latency_seconds_sum",
"vllm:e2e_request_latency_seconds_bucket",
"vllm:e2e_request_latency_seconds_count",
"vllm:request_queue_time_seconds_sum",
"vllm:request_queue_time_seconds_bucket",
"vllm:request_queue_time_seconds_count",
"vllm:request_inference_time_seconds_sum",
"vllm:request_inference_time_seconds_bucket",
"vllm:request_inference_time_seconds_count",
"vllm:request_prefill_time_seconds_sum",
"vllm:request_prefill_time_seconds_bucket",
"vllm:request_prefill_time_seconds_count",
"vllm:request_decode_time_seconds_sum",
"vllm:request_decode_time_seconds_bucket",
"vllm:request_decode_time_seconds_count",
]


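As a side note, each of the four new histograms follows the standard Prometheus convention of exposing _sum, _bucket and _count series, which is what the test above asserts. Below is a minimal sketch of the same check against a live server; it assumes an OpenAI-compatible vLLM server is already listening on localhost:8000 and that the requests library is installed, neither of which comes from this commit.

# Hedged sketch, not part of this commit: confirm the new request timing
# histograms show up on the Prometheus /metrics endpoint of a running server.
import requests

BASE_URL = "http://localhost:8000"  # assumed server address

EXPECTED_FAMILIES = [
    "vllm:request_queue_time_seconds",
    "vllm:request_inference_time_seconds",
    "vllm:request_prefill_time_seconds",
    "vllm:request_decode_time_seconds",
]

def check_timing_histograms() -> None:
    text = requests.get(f"{BASE_URL}/metrics").text
    for family in EXPECTED_FAMILIES:
        # A Prometheus histogram is exported as _sum, _bucket and _count series.
        for suffix in ("_sum", "_bucket", "_count"):
            assert f"{family}{suffix}" in text, f"missing {family}{suffix}"

if __name__ == "__main__":
    check_timing_histograms()
    print("all four request timing histograms are exported")
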
3 changes: 2 additions & 1 deletion tests/v1/core/test_scheduler.py
@@ -38,7 +38,8 @@ def create_scheduler(
return Scheduler(scheduler_config,
model_config,
cache_config,
lora_config=None)
lora_config=None,
log_stats=True)


def create_requests(
6 changes: 4 additions & 2 deletions tests/v1/engine/test_engine_core.py
@@ -50,7 +50,8 @@ def test_engine_core(monkeypatch):
executor_class = Executor.get_class(vllm_config)

engine_core = EngineCore(vllm_config=vllm_config,
executor_class=executor_class)
executor_class=executor_class,
log_stats=True)
"""Test basic request lifecycle."""

# First request.
@@ -157,7 +158,8 @@ def test_engine_core_advanced_sampling(monkeypatch):
executor_class = Executor.get_class(vllm_config)

engine_core = EngineCore(vllm_config=vllm_config,
executor_class=executor_class)
executor_class=executor_class,
log_stats=True)
"""Test basic request lifecycle."""
# First request.
request: EngineCoreRequest = make_request()
2 changes: 2 additions & 0 deletions tests/v1/engine/test_engine_core_client.py
@@ -94,6 +94,7 @@ def test_engine_core_client(monkeypatch, multiprocessing_mode: bool):
asyncio_mode=False,
vllm_config=vllm_config,
executor_class=executor_class,
log_stats=False,
)

MAX_TOKENS = 20
@@ -163,6 +164,7 @@ async def test_engine_core_client_asyncio(monkeypatch):
asyncio_mode=True,
vllm_config=vllm_config,
executor_class=executor_class,
log_stats=True,
)

MAX_TOKENS = 20
23 changes: 15 additions & 8 deletions tests/v1/engine/test_output_processor.py
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0

import math
import time
from typing import Dict, List, Optional

import pytest
@@ -15,6 +16,7 @@
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.output_processor import OutputProcessor
from vllm.v1.metrics.stats import IterationStats


def _ref_convert_id_to_token(
@@ -603,6 +605,7 @@ def test_iteration_stats(dummy_test_vectors):
output_processor = OutputProcessor(dummy_test_vectors.tokenizer_group,
log_stats=True)
engine_core = MockEngineCore(dummy_test_vectors.generation_tokens)
engine_core_timestamp = time.monotonic()

# Make N requests.
requests = [
@@ -630,8 +633,9 @@

# First iteration has 2 prefills.
outputs = engine_core.get_outputs()[:num_active]
processed_outputs = output_processor.process_outputs(outputs)
iteration_stats = processed_outputs.iteration_stats
iteration_stats = IterationStats()
output_processor.process_outputs(outputs, engine_core_timestamp,
iteration_stats)
total_prompt_tokens = sum([
len(prompt_tokens)
for prompt_tokens in dummy_test_vectors.prompt_tokens[:num_active]
@@ -642,8 +646,9 @@

# Just decodes in this step.
outputs = engine_core.get_outputs()[:num_active]
processed_outputs = output_processor.process_outputs(outputs)
iteration_stats = processed_outputs.iteration_stats
iteration_stats = IterationStats()
output_processor.process_outputs(outputs, engine_core_timestamp,
iteration_stats)

assert iteration_stats.num_prompt_tokens == 0
assert iteration_stats.num_generation_tokens == num_active
@@ -652,17 +657,19 @@
output_processor.add_request(inactive_request)
num_active += 1
outputs = engine_core.get_outputs()[:num_active]
processed_outputs = output_processor.process_outputs(outputs)
iteration_stats = processed_outputs.iteration_stats
iteration_stats = IterationStats()
output_processor.process_outputs(outputs, engine_core_timestamp,
iteration_stats)
total_prompt_tokens = len(dummy_test_vectors.prompt_tokens[num_active - 1])

assert iteration_stats.num_prompt_tokens == total_prompt_tokens
assert iteration_stats.num_generation_tokens == num_active

# Just decodes in this step.
outputs = engine_core.get_outputs()[:num_active]
processed_outputs = output_processor.process_outputs(outputs)
iteration_stats = processed_outputs.iteration_stats
iteration_stats = IterationStats()
output_processor.process_outputs(outputs, engine_core_timestamp,
iteration_stats)

assert iteration_stats.num_prompt_tokens == 0
assert iteration_stats.num_generation_tokens == num_active
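
The updated test reflects the new calling convention: the caller now constructs the IterationStats object and passes the engine core timestamp in explicitly, instead of reading both back from the processed outputs. The following is a hedged sketch of how a frontend loop might wire this together, using only the names visible in this diff (the outputs and timestamp fields of EngineCoreOutputs); the helper function itself is illustrative rather than part of the change.

# Illustrative helper, not from this commit: turn one batch of engine core
# outputs into per-iteration stats using the new process_outputs() signature.
from vllm.v1.engine import EngineCoreOutputs
from vllm.v1.engine.output_processor import OutputProcessor
from vllm.v1.metrics.stats import IterationStats

def process_one_iteration(output_processor: OutputProcessor,
                          engine_core_outputs: EngineCoreOutputs) -> IterationStats:
    iteration_stats = IterationStats()
    # The engine core stamps its outputs with time.monotonic(); forwarding
    # that timestamp lets the frontend measure request time intervals.
    output_processor.process_outputs(engine_core_outputs.outputs,
                                     engine_core_outputs.timestamp,
                                     iteration_stats)
    return iteration_stats
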
3 changes: 3 additions & 0 deletions vllm/v1/core/kv_cache_manager.py
@@ -26,13 +26,16 @@ def __init__(
sliding_window: Optional[int] = None,
enable_caching: bool = True,
num_preallocate_tokens: int = 64,
log_stats: bool = False,
) -> None:
self.block_size = block_size
self.num_gpu_blocks = num_gpu_blocks
self.max_model_len = max_model_len
self.max_num_blocks_per_req = cdiv(max_model_len, block_size)
self.sliding_window = sliding_window
self.enable_caching = enable_caching
# FIXME: make prefix cache stats conditional on log_stats
self.log_stats = log_stats
# NOTE(woosuk): To avoid frequent block allocation, we preallocate some
# blocks for each request. For example, when a request reaches the end
# of its block table, we preallocate N blocks in advance. This way, we
33 changes: 29 additions & 4 deletions vllm/v1/core/scheduler.py
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0

import time
from collections import deque
from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union

@@ -10,7 +11,8 @@
from vllm.v1.core.kv_cache_manager import KVCacheManager
from vllm.v1.core.scheduler_output import (CachedRequestData, NewRequestData,
SchedulerOutput)
from vllm.v1.engine import EngineCoreOutput, EngineCoreOutputs
from vllm.v1.engine import (EngineCoreEvent, EngineCoreEventType,
EngineCoreOutput, EngineCoreOutputs)
from vllm.v1.metrics.stats import SchedulerStats
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
@@ -26,10 +28,12 @@ def __init__(
model_config: ModelConfig,
cache_config: CacheConfig,
lora_config: Optional[LoRAConfig],
log_stats: bool,
) -> None:
self.scheduler_config = scheduler_config
self.cache_config = cache_config
self.lora_config = lora_config
self.log_stats = log_stats

# Scheduling constraints.
self.max_num_running_reqs = self.scheduler_config.max_num_seqs
@@ -45,7 +49,8 @@ def __init__(
num_gpu_blocks=num_gpu_blocks,
max_model_len=self.max_model_len,
sliding_window=self.cache_config.sliding_window,
enable_caching=self.cache_config.enable_prefix_caching)
enable_caching=self.cache_config.enable_prefix_caching,
log_stats=self.log_stats)
self.block_size = self.cache_config.block_size

# req_id -> Request
@@ -107,6 +112,8 @@ def schedule(self) -> "SchedulerOutput":
scheduled_encoder_inputs: Dict[str, List[int]] = {}
encoder_budget = self.max_num_encoder_input_tokens

scheduled_timestamp = time.monotonic()

# First, schedule the RUNNING requests.
req_index = 0
while req_index < len(self.running) and token_budget > 0:
@@ -246,6 +253,7 @@ def schedule(self) -> "SchedulerOutput":
self.running.append(request)
if request.status == RequestStatus.WAITING:
scheduled_new_reqs.append(request)
self.request_scheduled(request, scheduled_timestamp)
elif request.status == RequestStatus.PREEMPTED:
scheduled_resumed_reqs.append(request)
else:
@@ -508,7 +516,8 @@ def update_from_output(
finish_reason=request.get_finished_reason(),
new_logprobs=new_logprobs,
new_prompt_logprobs_tensors=prompt_logprobs_tensors,
stop_reason=request.stop_reason))
stop_reason=request.stop_reason,
events=request.take_events()))

if not stopped:
new_running.append(request)
@@ -541,6 +550,7 @@ def _check_stop(self, request: Request) -> bool:
def add_request(self, request: Request) -> None:
self.waiting.append(request)
self.requests[request.request_id] = request
self.request_queued(request)

def finish_requests(
self,
@@ -588,7 +598,22 @@ def has_unfinished_requests(self) -> bool:
def reset_prefix_cache(self) -> bool:
return self.kv_cache_manager.reset_prefix_cache()

def make_stats(self) -> SchedulerStats:
def request_queued(self, request: Request):
if not self.log_stats:
return
request.events.append(
EngineCoreEvent.new_event(EngineCoreEventType.QUEUED))

def request_scheduled(self, request: Request, timestamp: float):
if not self.log_stats:
return
request.events.append(
EngineCoreEvent.new_event(EngineCoreEventType.SCHEDULED,
timestamp))

def make_stats(self) -> Optional[SchedulerStats]:
if not self.log_stats:
return None
return SchedulerStats(
num_running_reqs=len(self.running),
num_waiting_reqs=len(self.waiting),
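
Because make_stats() now returns None whenever log_stats is disabled, downstream consumers have to treat scheduler stats as optional. Here is a small sketch of that pattern, assuming only the SchedulerStats fields visible in this diff (num_running_reqs, num_waiting_reqs); the reporting helper itself is hypothetical, not part of the commit.

# Hypothetical consumer of the now-optional scheduler stats.
from typing import Optional

from vllm.v1.metrics.stats import SchedulerStats

def report_scheduler_stats(stats: Optional[SchedulerStats]) -> None:
    if stats is None:
        # log_stats=False: the scheduler skipped stats collection entirely.
        return
    print(f"running={stats.num_running_reqs} waiting={stats.num_waiting_reqs}")
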
33 changes: 32 additions & 1 deletion vllm/v1/engine/__init__.py
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0

import enum
import time
from typing import List, Optional, Union

import msgspec
@@ -60,6 +61,30 @@ class EngineCoreRequest(
lora_request: Optional[LoRARequest]


class EngineCoreEventType(enum.IntEnum):
"""The type of engine core request event."""
QUEUED = 1
SCHEDULED = 2


class EngineCoreEvent(msgspec.Struct):
"""A timestamped engine core event associated with a request.
The timestamp is monotonic and is used by the engine frontend to
calculate intervals between engine core events. These timestamps
should not be compared with timestamps from other processes.
"""
type: EngineCoreEventType
timestamp: float

@classmethod
def new_event(cls,
event_type: EngineCoreEventType,
timestamp: Optional[float] = None) -> "EngineCoreEvent":
timestamp = time.monotonic() if timestamp is None else timestamp
return cls(event_type, timestamp)


class EngineCoreOutput(
msgspec.Struct,
array_like=True, # type: ignore[call-arg]
@@ -74,6 +99,7 @@ class EngineCoreOutput(

finish_reason: Optional[FinishReason] = None
stop_reason: Union[int, str, None] = None
events: Optional[List[EngineCoreEvent]] = None

@property
def finished(self) -> bool:
@@ -91,7 +117,12 @@ class EngineCoreOutputs(

# [num_reqs]
outputs: List[EngineCoreOutput]
scheduler_stats: SchedulerStats
scheduler_stats: Optional[SchedulerStats]
timestamp: float = 0.0

def __post_init__(self):
if self.timestamp == 0.0:
self.timestamp = time.monotonic()


class EngineCoreRequestType(enum.Enum):
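
The QUEUED and SCHEDULED events carry monotonic timestamps taken in the engine core process, so the frontend can subtract one from the other to get a queue interval, which is the kind of quantity the new vllm:request_queue_time_seconds histogram is meant to capture. A hedged sketch of that calculation follows; the helper function is illustrative, not part of this commit.

# Illustrative only: derive a request's queue time from its event list.
from typing import List, Optional

from vllm.v1.engine import EngineCoreEvent, EngineCoreEventType

def queue_time_seconds(events: List[EngineCoreEvent]) -> Optional[float]:
    queued = next((e.timestamp for e in events
                   if e.type == EngineCoreEventType.QUEUED), None)
    scheduled = next((e.timestamp for e in events
                      if e.type == EngineCoreEventType.SCHEDULED), None)
    if queued is None or scheduled is None:
        return None
    # Both timestamps come from time.monotonic() in the engine core process,
    # so the difference is meaningful even though the absolute values are not.
    return scheduled - queued
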
