[Bugfix] Fix Ray Metrics API usage #6354

Merged · 5 commits · Jul 17, 2024
Changes from 2 commits
54 changes: 54 additions & 0 deletions tests/metrics/test_metrics.py
@@ -1,11 +1,13 @@
from typing import List

import pytest
import ray
from prometheus_client import REGISTRY

from vllm import EngineArgs, LLMEngine
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.metrics import RayPrometheusStatLogger
from vllm.sampling_params import SamplingParams

MODELS = [
@@ -192,3 +194,55 @@ def assert_metrics(engine: LLMEngine, disable_log_stats: bool,
labels)
assert (
metric_value == num_requests), "Metrics should be collected"


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [16])
def test_engine_log_metrics_ray(
example_prompts,
model: str,
dtype: str,
max_tokens: int,
) -> None:
# This test is quite weak - it only checks that we can use
# RayPrometheusStatLogger without exceptions.
# Checking whether the metrics are actually emitted is unfortunately
# non-trivial.

# We have to run in a Ray task for Ray metrics to be emitted correctly
@ray.remote(num_gpus=1)
def _inner():

class _RayPrometheusStatLogger(RayPrometheusStatLogger):

def __init__(self, *args, **kwargs):
self._i = 0
super().__init__(*args, **kwargs)

def log(self, *args, **kwargs):
self._i += 1
return super().log(*args, **kwargs)

engine_args = EngineArgs(
model=model,
dtype=dtype,
disable_log_stats=False,
)
engine = LLMEngine.from_engine_args(engine_args)
logger = _RayPrometheusStatLogger(
local_interval=0.5,
labels=dict(model_name=engine.model_config.served_model_name),
max_model_len=engine.model_config.max_model_len)
engine.add_logger("ray", logger)
for i, prompt in enumerate(example_prompts):
engine.add_request(
f"request-id-{i}",
prompt,
SamplingParams(max_tokens=max_tokens),
)
while engine.has_unfinished_requests():
engine.step()
assert logger._i > 0, ".log must be called at least once"

ray.get(_inner.remote())
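
Outside of the test, wiring the Ray-backed logger into an engine follows the same pattern. Below is a minimal sketch that reuses only the calls exercised in the test above (EngineArgs, LLMEngine.from_engine_args, add_logger, RayPrometheusStatLogger); the helper name, model name, and prompt are illustrative placeholders, and the work still has to run inside a Ray task for Ray metrics to be emitted.

import ray

from vllm import EngineArgs, LLMEngine
from vllm.engine.metrics import RayPrometheusStatLogger
from vllm.sampling_params import SamplingParams


@ray.remote(num_gpus=1)
def run_with_ray_metrics(prompts):
    # Keep stat logging enabled so the registered loggers are invoked.
    engine_args = EngineArgs(model="facebook/opt-125m", disable_log_stats=False)
    engine = LLMEngine.from_engine_args(engine_args)

    # Register the Ray stat logger alongside the default ones.
    logger = RayPrometheusStatLogger(
        local_interval=0.5,
        labels=dict(model_name=engine.model_config.served_model_name),
        max_model_len=engine.model_config.max_model_len)
    engine.add_logger("ray", logger)

    for i, prompt in enumerate(prompts):
        engine.add_request(f"req-{i}", prompt, SamplingParams(max_tokens=16))
    while engine.has_unfinished_requests():
        engine.step()


ray.get(run_with_ray_metrics.remote(["Hello, my name is"]))
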
138 changes: 110 additions & 28 deletions vllm/engine/metrics.py
@@ -30,63 +30,63 @@
# begin-metrics-definitions
class Metrics:
labelname_finish_reason = "finished_reason"
_base_library = prometheus_client
_gauge_cls = prometheus_client.Gauge
_counter_cls = prometheus_client.Counter
_histogram_cls = prometheus_client.Histogram

def __init__(self, labelnames: List[str], max_model_len: int):
# Unregister any existing vLLM collectors
self._unregister_vllm_metrics()

# Config Information
self.info_cache_config = prometheus_client.Info(
name='vllm:cache_config',
documentation='information of cache_config')
self._create_info_cache_config()

# System stats
# Scheduler State
self.gauge_scheduler_running = self._base_library.Gauge(
self.gauge_scheduler_running = self._gauge_cls(
name="vllm:num_requests_running",
documentation="Number of requests currently running on GPU.",
labelnames=labelnames)
self.gauge_scheduler_waiting = self._base_library.Gauge(
self.gauge_scheduler_waiting = self._gauge_cls(
name="vllm:num_requests_waiting",
documentation="Number of requests waiting to be processed.",
labelnames=labelnames)
self.gauge_scheduler_swapped = self._base_library.Gauge(
self.gauge_scheduler_swapped = self._gauge_cls(
name="vllm:num_requests_swapped",
documentation="Number of requests swapped to CPU.",
labelnames=labelnames)
# KV Cache Usage in %
self.gauge_gpu_cache_usage = self._base_library.Gauge(
self.gauge_gpu_cache_usage = self._gauge_cls(
name="vllm:gpu_cache_usage_perc",
documentation="GPU KV-cache usage. 1 means 100 percent usage.",
labelnames=labelnames)
self.gauge_cpu_cache_usage = self._base_library.Gauge(
self.gauge_cpu_cache_usage = self._gauge_cls(
name="vllm:cpu_cache_usage_perc",
documentation="CPU KV-cache usage. 1 means 100 percent usage.",
labelnames=labelnames)

# Iteration stats
self.counter_num_preemption = self._base_library.Counter(
self.counter_num_preemption = self._counter_cls(
name="vllm:num_preemptions_total",
documentation="Cumulative number of preemption from the engine.",
labelnames=labelnames)
self.counter_prompt_tokens = self._base_library.Counter(
self.counter_prompt_tokens = self._counter_cls(
name="vllm:prompt_tokens_total",
documentation="Number of prefill tokens processed.",
labelnames=labelnames)
self.counter_generation_tokens = self._base_library.Counter(
self.counter_generation_tokens = self._counter_cls(
name="vllm:generation_tokens_total",
documentation="Number of generation tokens processed.",
labelnames=labelnames)
self.histogram_time_to_first_token = self._base_library.Histogram(
self.histogram_time_to_first_token = self._histogram_cls(
name="vllm:time_to_first_token_seconds",
documentation="Histogram of time to first token in seconds.",
labelnames=labelnames,
buckets=[
0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.25, 0.5,
0.75, 1.0, 2.5, 5.0, 7.5, 10.0
])
self.histogram_time_per_output_token = self._base_library.Histogram(
self.histogram_time_per_output_token = self._histogram_cls(
name="vllm:time_per_output_token_seconds",
documentation="Histogram of time per output token in seconds.",
labelnames=labelnames,
@@ -97,67 +97,145 @@ def __init__(self, labelnames: List[str], max_model_len: int):

# Request stats
# Latency
self.histogram_e2e_time_request = self._base_library.Histogram(
self.histogram_e2e_time_request = self._histogram_cls(
name="vllm:e2e_request_latency_seconds",
documentation="Histogram of end to end request latency in seconds.",
labelnames=labelnames,
buckets=[1.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 50.0, 60.0])
# Metadata
self.histogram_num_prompt_tokens_request = self._base_library.Histogram(
self.histogram_num_prompt_tokens_request = self._histogram_cls(
name="vllm:request_prompt_tokens",
documentation="Number of prefill tokens processed.",
labelnames=labelnames,
buckets=build_1_2_5_buckets(max_model_len),
)
self.histogram_num_generation_tokens_request = \
self._base_library.Histogram(
self._histogram_cls(
name="vllm:request_generation_tokens",
documentation="Number of generation tokens processed.",
labelnames=labelnames,
buckets=build_1_2_5_buckets(max_model_len),
)
self.histogram_best_of_request = self._base_library.Histogram(
self.histogram_best_of_request = self._histogram_cls(
name="vllm:request_params_best_of",
documentation="Histogram of the best_of request parameter.",
labelnames=labelnames,
buckets=[1, 2, 5, 10, 20],
)
self.histogram_n_request = self._base_library.Histogram(
self.histogram_n_request = self._histogram_cls(
name="vllm:request_params_n",
documentation="Histogram of the n request parameter.",
labelnames=labelnames,
buckets=[1, 2, 5, 10, 20],
)
self.counter_request_success = self._base_library.Counter(
self.counter_request_success = self._counter_cls(
name="vllm:request_success_total",
documentation="Count of successfully processed requests.",
labelnames=labelnames + [Metrics.labelname_finish_reason])

# Deprecated in favor of vllm:prompt_tokens_total
self.gauge_avg_prompt_throughput = self._base_library.Gauge(
self.gauge_avg_prompt_throughput = self._gauge_cls(
name="vllm:avg_prompt_throughput_toks_per_s",
documentation="Average prefill throughput in tokens/s.",
labelnames=labelnames,
)
# Deprecated in favor of vllm:generation_tokens_total
self.gauge_avg_generation_throughput = self._base_library.Gauge(
self.gauge_avg_generation_throughput = self._gauge_cls(
name="vllm:avg_generation_throughput_toks_per_s",
documentation="Average generation throughput in tokens/s.",
labelnames=labelnames,
)

def _create_info_cache_config(self) -> None:
# Config Information
self.info_cache_config = prometheus_client.Info(
name='vllm:cache_config',
documentation='information of cache_config')

def _unregister_vllm_metrics(self) -> None:
for collector in list(self._base_library.REGISTRY._collector_to_names):
for collector in list(prometheus_client.REGISTRY._collector_to_names):
if hasattr(collector, "_name") and "vllm" in collector._name:
self._base_library.REGISTRY.unregister(collector)
prometheus_client.REGISTRY.unregister(collector)


# end-metrics-definitions
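
The reason for replacing the single `_base_library` hook with `_gauge_cls` / `_counter_cls` / `_histogram_cls` is that `ray.util.metrics` does not mirror `prometheus_client`'s constructors (it takes `description=` and `tag_keys=` rather than `documentation=` and `labelnames=`, as the wrappers below show), so swapping the module wholesale breaks the `Metrics.__init__` call sites. Per-type class attributes let a subclass substitute adapter classes instead. The following is a rough sketch of the same pattern with made-up names, plugging an in-memory gauge into a stripped-down base class:

import prometheus_client


class DemoMetrics:
    # Subclasses override this hook to swap the gauge implementation
    # without touching __init__.
    _gauge_cls = prometheus_client.Gauge

    def __init__(self, labelnames):
        self.gauge_running = self._gauge_cls(
            name="demo:num_requests_running",
            documentation="Number of requests currently running.",
            labelnames=labelnames)


class _InMemoryGauge:
    """Stand-in gauge exposing the same name/documentation/labelnames API."""

    def __init__(self, name, documentation="", labelnames=None):
        self.name = name
        self.values = {}
        self._key = ()

    def labels(self, **labels):
        self._key = tuple(sorted(labels.items()))
        return self

    def set(self, value):
        self.values[self._key] = value


class InMemoryDemoMetrics(DemoMetrics):
    # Only the class attribute changes; __init__ is inherited as-is.
    _gauge_cls = _InMemoryGauge


m = InMemoryDemoMetrics(labelnames=["model_name"])
m.gauge_running.labels(model_name="demo").set(3)
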


class _RayGaugeWrapper:
"""Wraps around ray.util.metrics.Gauge to provide same API as
prometheus_client.Gauge"""

def __init__(self,
name: str,
documentation: str = "",
labelnames: Optional[List[str]] = None):
labelnames_tuple = tuple(labelnames) if labelnames else None
self._gauge = ray_metrics.Gauge(name=name,
description=documentation,
tag_keys=labelnames_tuple)

def labels(self, **labels):
self._gauge.set_default_tags(labels)
return self

def set(self, value: Union[int, float]):
return self._gauge.set(value)


class _RayCounterWrapper:
"""Wraps around ray.util.metrics.Counter to provide same API as
prometheus_client.Counter"""

def __init__(self,
name: str,
documentation: str = "",
labelnames: Optional[List[str]] = None):
labelnames_tuple = tuple(labelnames) if labelnames else None
self._counter = ray_metrics.Counter(name=name,
description=documentation,
tag_keys=labelnames_tuple)

def labels(self, **labels):
self._counter.set_default_tags(labels)
return self

def inc(self, value: Union[int, float] = 1.0):
if value == 0:
return
return self._counter.inc(value)


class _RayHistogramWrapper:
"""Wraps around ray.util.metrics.Histogram to provide same API as
prometheus_client.Histogram"""

def __init__(self,
name: str,
documentation: str = "",
labelnames: Optional[List[str]] = None,
buckets: Optional[List[float]] = None):
labelnames_tuple = tuple(labelnames) if labelnames else None
self._histogram = ray_metrics.Histogram(name=name,
description=documentation,
tag_keys=labelnames_tuple,
boundaries=buckets)

def labels(self, **labels):
self._histogram.set_default_tags(labels)
return self

def observe(self, value: Union[int, float]):
return self._histogram.observe(value)
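
For reference, the call pattern these wrappers translate looks roughly like the sketch below: stat-logging code keeps using the prometheus_client-style labels(...).set(...) chain, and the wrapper forwards it to the same ray.util.metrics calls used above (set_default_tags plus set). The metric name, tag, and value here are illustrative, and Ray should already be initialized (ideally inside a Ray task, as the test notes).

import ray
from ray.util import metrics as ray_metrics

ray.init(ignore_reinit_error=True)

# What a stat logger writes through the wrapper:
#     gauge.labels(model_name="demo").set(0.42)
# What the wrapper issues against Ray's API:
gauge = ray_metrics.Gauge(
    name="vllm:gpu_cache_usage_perc",
    description="GPU KV-cache usage. 1 means 100 percent usage.",
    tag_keys=("model_name", ))
gauge.set_default_tags({"model_name": "demo"})
gauge.set(0.42)
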


class RayMetrics(Metrics):
"""
RayMetrics is used by RayPrometheusStatLogger to log to Ray metrics.
Provides the same metrics as Metrics but uses Ray's util.metrics library.
"""
_base_library = ray_metrics
_gauge_cls = _RayGaugeWrapper
_counter_cls = _RayCounterWrapper
_histogram_cls = _RayHistogramWrapper

def __init__(self, labelnames: List[str], max_model_len: int):
if ray_metrics is None:
@@ -168,8 +246,9 @@ def _unregister_vllm_metrics(self) -> None:
# No-op on purpose
pass


# end-metrics-definitions
def _create_info_cache_config(self) -> None:
# No-op on purpose
pass

Review comment on this method: thanks for fixing this as well!


def build_1_2_5_buckets(max_value: int) -> List[int]:
@@ -457,4 +536,7 @@ def log(self, stats: Stats):

class RayPrometheusStatLogger(PrometheusStatLogger):
"""RayPrometheusStatLogger uses Ray metrics instead."""
_metrics_cls = RayMetrics

def info(self, type: str, obj: SupportsMetricsInfo) -> None:
return None
Loading