From 773327a52e42f272161dcbe630ef69d3715923dd Mon Sep 17 00:00:00 2001
From: Antoni Baum
Date: Wed, 17 Jul 2024 12:40:10 -0700
Subject: [PATCH] [Bugfix] Fix Ray Metrics API usage (#6354)

---
 tests/metrics/test_metrics.py   |  54 +++++++++++
 vllm/engine/async_llm_engine.py |  19 ++++
 vllm/engine/llm_engine.py       |   2 +
 vllm/engine/metrics.py          | 160 ++++++++++++++++++++++++--------
 4 files changed, 195 insertions(+), 40 deletions(-)

diff --git a/tests/metrics/test_metrics.py b/tests/metrics/test_metrics.py
index 5694061e17e02..42b15cd6c458e 100644
--- a/tests/metrics/test_metrics.py
+++ b/tests/metrics/test_metrics.py
@@ -1,11 +1,13 @@
 from typing import List

 import pytest
+import ray
 from prometheus_client import REGISTRY

 from vllm import EngineArgs, LLMEngine
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
+from vllm.engine.metrics import RayPrometheusStatLogger
 from vllm.sampling_params import SamplingParams

 MODELS = [
@@ -241,3 +243,55 @@ def assert_metrics(engine: LLMEngine, disable_log_stats: bool,
                                                  labels)
             assert (
                 metric_value == num_requests), "Metrics should be collected"
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("max_tokens", [16])
+def test_engine_log_metrics_ray(
+    example_prompts,
+    model: str,
+    dtype: str,
+    max_tokens: int,
+) -> None:
+    # This test is quite weak - it only checks that we can use
+    # RayPrometheusStatLogger without exceptions.
+    # Checking whether the metrics are actually emitted is unfortunately
+    # non-trivial.
+
+    # We have to run in a Ray task for Ray metrics to be emitted correctly
+    @ray.remote(num_gpus=1)
+    def _inner():
+
+        class _RayPrometheusStatLogger(RayPrometheusStatLogger):
+
+            def __init__(self, *args, **kwargs):
+                self._i = 0
+                super().__init__(*args, **kwargs)
+
+            def log(self, *args, **kwargs):
+                self._i += 1
+                return super().log(*args, **kwargs)
+
+        engine_args = EngineArgs(
+            model=model,
+            dtype=dtype,
+            disable_log_stats=False,
+        )
+        engine = LLMEngine.from_engine_args(engine_args)
+        logger = _RayPrometheusStatLogger(
+            local_interval=0.5,
+            labels=dict(model_name=engine.model_config.served_model_name),
+            max_model_len=engine.model_config.max_model_len)
+        engine.add_logger("ray", logger)
+        for i, prompt in enumerate(example_prompts):
+            engine.add_request(
+                f"request-id-{i}",
+                prompt,
+                SamplingParams(max_tokens=max_tokens),
+            )
+        while engine.has_unfinished_requests():
+            engine.step()
+        assert logger._i > 0, ".log must be called at least once"
+
+    ray.get(_inner.remote())
diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py
index c1ad99f4deaa0..5b0a60d4f2a38 100644
--- a/vllm/engine/async_llm_engine.py
+++ b/vllm/engine/async_llm_engine.py
@@ -12,6 +12,7 @@
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_timeout import asyncio_timeout
 from vllm.engine.llm_engine import LLMEngine
+from vllm.engine.metrics import StatLoggerBase
 from vllm.executor.ray_utils import initialize_ray_cluster, ray
 from vllm.inputs import LLMInputs, PromptInputs
 from vllm.logger import init_logger
@@ -392,6 +393,7 @@ def from_engine_args(
         engine_args: AsyncEngineArgs,
         start_engine_loop: bool = True,
         usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
+        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
     ) -> "AsyncLLMEngine":
         """Creates an async LLM engine from the engine arguments."""
         # Create the engine configs.
@@ -454,6 +456,7 @@ def from_engine_args(
             max_log_len=engine_args.max_log_len,
             start_engine_loop=start_engine_loop,
             usage_context=usage_context,
+            stat_loggers=stat_loggers,
         )
         return engine

@@ -969,3 +972,19 @@ async def is_tracing_enabled(self) -> bool:
             )
         else:
             return self.engine.is_tracing_enabled()
+
+    def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
+        if self.engine_use_ray:
+            ray.get(
+                self.engine.add_logger.remote(  # type: ignore
+                    logger_name=logger_name, logger=logger))
+        else:
+            self.engine.add_logger(logger_name=logger_name, logger=logger)
+
+    def remove_logger(self, logger_name: str) -> None:
+        if self.engine_use_ray:
+            ray.get(
+                self.engine.remove_logger.remote(  # type: ignore
+                    logger_name=logger_name))
+        else:
+            self.engine.remove_logger(logger_name=logger_name)
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 622221d2dd13e..68ca9a97a3c61 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -379,6 +379,7 @@ def from_engine_args(
         cls,
         engine_args: EngineArgs,
         usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
+        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
     ) -> "LLMEngine":
         """Creates an LLM engine from the engine arguments."""
         # Create the engine configs.
@@ -423,6 +424,7 @@ def from_engine_args(
             executor_class=executor_class,
             log_stats=not engine_args.disable_log_stats,
             usage_context=usage_context,
+            stat_loggers=stat_loggers,
         )
         return engine

diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py
index 9e187b2fbc33b..4ed7da2377111 100644
--- a/vllm/engine/metrics.py
+++ b/vllm/engine/metrics.py
@@ -30,55 +30,55 @@
 # begin-metrics-definitions
 class Metrics:
     labelname_finish_reason = "finished_reason"
-    _base_library = prometheus_client
+    _gauge_cls = prometheus_client.Gauge
+    _counter_cls = prometheus_client.Counter
+    _histogram_cls = prometheus_client.Histogram

     def __init__(self, labelnames: List[str], max_model_len: int):
         # Unregister any existing vLLM collectors
         self._unregister_vllm_metrics()

         # Config Information
-        self.info_cache_config = prometheus_client.Info(
-            name='vllm:cache_config',
-            documentation='information of cache_config')
+        self._create_info_cache_config()

         # System stats
         #   Scheduler State
-        self.gauge_scheduler_running = self._base_library.Gauge(
+        self.gauge_scheduler_running = self._gauge_cls(
             name="vllm:num_requests_running",
             documentation="Number of requests currently running on GPU.",
             labelnames=labelnames)
-        self.gauge_scheduler_waiting = self._base_library.Gauge(
+        self.gauge_scheduler_waiting = self._gauge_cls(
             name="vllm:num_requests_waiting",
             documentation="Number of requests waiting to be processed.",
             labelnames=labelnames)
-        self.gauge_scheduler_swapped = self._base_library.Gauge(
+        self.gauge_scheduler_swapped = self._gauge_cls(
             name="vllm:num_requests_swapped",
             documentation="Number of requests swapped to CPU.",
             labelnames=labelnames)
         #   KV Cache Usage in %
-        self.gauge_gpu_cache_usage = self._base_library.Gauge(
+        self.gauge_gpu_cache_usage = self._gauge_cls(
             name="vllm:gpu_cache_usage_perc",
             documentation="GPU KV-cache usage. 1 means 100 percent usage.",
             labelnames=labelnames)
-        self.gauge_cpu_cache_usage = self._base_library.Gauge(
+        self.gauge_cpu_cache_usage = self._gauge_cls(
             name="vllm:cpu_cache_usage_perc",
             documentation="CPU KV-cache usage. 1 means 100 percent usage.",
             labelnames=labelnames)

         # Iteration stats
-        self.counter_num_preemption = self._base_library.Counter(
+        self.counter_num_preemption = self._counter_cls(
             name="vllm:num_preemptions_total",
             documentation="Cumulative number of preemption from the engine.",
             labelnames=labelnames)
-        self.counter_prompt_tokens = self._base_library.Counter(
+        self.counter_prompt_tokens = self._counter_cls(
             name="vllm:prompt_tokens_total",
             documentation="Number of prefill tokens processed.",
             labelnames=labelnames)
-        self.counter_generation_tokens = self._base_library.Counter(
+        self.counter_generation_tokens = self._counter_cls(
             name="vllm:generation_tokens_total",
             documentation="Number of generation tokens processed.",
             labelnames=labelnames)
-        self.histogram_time_to_first_token = self._base_library.Histogram(
+        self.histogram_time_to_first_token = self._histogram_cls(
             name="vllm:time_to_first_token_seconds",
             documentation="Histogram of time to first token in seconds.",
             labelnames=labelnames,
@@ -86,7 +86,7 @@ def __init__(self, labelnames: List[str], max_model_len: int):
                 0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.25, 0.5,
                 0.75, 1.0, 2.5, 5.0, 7.5, 10.0
             ])
-        self.histogram_time_per_output_token = self._base_library.Histogram(
+        self.histogram_time_per_output_token = self._histogram_cls(
             name="vllm:time_per_output_token_seconds",
             documentation="Histogram of time per output token in seconds.",
             labelnames=labelnames,
@@ -97,83 +97,157 @@ def __init__(self, labelnames: List[str], max_model_len: int):

         # Request stats
         #   Latency
-        self.histogram_e2e_time_request = self._base_library.Histogram(
+        self.histogram_e2e_time_request = self._histogram_cls(
             name="vllm:e2e_request_latency_seconds",
             documentation="Histogram of end to end request latency in seconds.",
             labelnames=labelnames,
             buckets=[1.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 50.0, 60.0])
         #   Metadata
-        self.histogram_num_prompt_tokens_request = self._base_library.Histogram(
+        self.histogram_num_prompt_tokens_request = self._histogram_cls(
             name="vllm:request_prompt_tokens",
             documentation="Number of prefill tokens processed.",
             labelnames=labelnames,
             buckets=build_1_2_5_buckets(max_model_len),
         )
         self.histogram_num_generation_tokens_request = \
-            self._base_library.Histogram(
+            self._histogram_cls(
                 name="vllm:request_generation_tokens",
                 documentation="Number of generation tokens processed.",
                 labelnames=labelnames,
                 buckets=build_1_2_5_buckets(max_model_len),
             )
-        self.histogram_best_of_request = self._base_library.Histogram(
+        self.histogram_best_of_request = self._histogram_cls(
             name="vllm:request_params_best_of",
             documentation="Histogram of the best_of request parameter.",
             labelnames=labelnames,
             buckets=[1, 2, 5, 10, 20],
         )
-        self.histogram_n_request = self._base_library.Histogram(
+        self.histogram_n_request = self._histogram_cls(
             name="vllm:request_params_n",
             documentation="Histogram of the n request parameter.",
             labelnames=labelnames,
             buckets=[1, 2, 5, 10, 20],
         )
-        self.counter_request_success = self._base_library.Counter(
+        self.counter_request_success = self._counter_cls(
             name="vllm:request_success_total",
             documentation="Count of successfully processed requests.",
             labelnames=labelnames + [Metrics.labelname_finish_reason])

         # Speculatie decoding stats
-        self.gauge_spec_decode_draft_acceptance_rate = self._base_library.Gauge(
+        self.gauge_spec_decode_draft_acceptance_rate = self._gauge_cls(
             name="vllm:spec_decode_draft_acceptance_rate",
             documentation="Speulative token acceptance rate.",
             labelnames=labelnames)
-        self.gauge_spec_decode_efficiency = self._base_library.Gauge(
+        self.gauge_spec_decode_efficiency = self._gauge_cls(
             name="vllm:spec_decode_efficiency",
             documentation="Speculative decoding system efficiency.",
             labelnames=labelnames)
-        self.counter_spec_decode_num_accepted_tokens = (
-            self._base_library.Counter(
-                name="vllm:spec_decode_num_accepted_tokens_total",
-                documentation="Number of accepted tokens.",
-                labelnames=labelnames))
+        self.counter_spec_decode_num_accepted_tokens = (self._counter_cls(
+            name="vllm:spec_decode_num_accepted_tokens_total",
+            documentation="Number of accepted tokens.",
+            labelnames=labelnames))
-        self.counter_spec_decode_num_draft_tokens = self._base_library.Counter(
+        self.counter_spec_decode_num_draft_tokens = self._counter_cls(
             name="vllm:spec_decode_num_draft_tokens_total",
             documentation="Number of draft tokens.",
             labelnames=labelnames)
-        self.counter_spec_decode_num_emitted_tokens = (
-            self._base_library.Counter(
-                name="vllm:spec_decode_num_emitted_tokens_total",
-                documentation="Number of emitted tokens.",
-                labelnames=labelnames))
+        self.counter_spec_decode_num_emitted_tokens = (self._counter_cls(
+            name="vllm:spec_decode_num_emitted_tokens_total",
+            documentation="Number of emitted tokens.",
+            labelnames=labelnames))

         # Deprecated in favor of vllm:prompt_tokens_total
-        self.gauge_avg_prompt_throughput = self._base_library.Gauge(
+        self.gauge_avg_prompt_throughput = self._gauge_cls(
             name="vllm:avg_prompt_throughput_toks_per_s",
             documentation="Average prefill throughput in tokens/s.",
             labelnames=labelnames,
         )
         # Deprecated in favor of vllm:generation_tokens_total
-        self.gauge_avg_generation_throughput = self._base_library.Gauge(
+        self.gauge_avg_generation_throughput = self._gauge_cls(
             name="vllm:avg_generation_throughput_toks_per_s",
             documentation="Average generation throughput in tokens/s.",
             labelnames=labelnames,
         )

+    def _create_info_cache_config(self) -> None:
+        # Config Information
+        self.info_cache_config = prometheus_client.Info(
+            name='vllm:cache_config',
+            documentation='information of cache_config')
+
     def _unregister_vllm_metrics(self) -> None:
-        for collector in list(self._base_library.REGISTRY._collector_to_names):
+        for collector in list(prometheus_client.REGISTRY._collector_to_names):
             if hasattr(collector, "_name") and "vllm" in collector._name:
-                self._base_library.REGISTRY.unregister(collector)
+                prometheus_client.REGISTRY.unregister(collector)
+
+
+# end-metrics-definitions
+
+
+class _RayGaugeWrapper:
+    """Wraps around ray.util.metrics.Gauge to provide same API as
+    prometheus_client.Gauge"""
+
+    def __init__(self,
+                 name: str,
+                 documentation: str = "",
+                 labelnames: Optional[List[str]] = None):
+        labelnames_tuple = tuple(labelnames) if labelnames else None
+        self._gauge = ray_metrics.Gauge(name=name,
+                                        description=documentation,
+                                        tag_keys=labelnames_tuple)
+
+    def labels(self, **labels):
+        self._gauge.set_default_tags(labels)
+        return self
+
+    def set(self, value: Union[int, float]):
+        return self._gauge.set(value)
+
+
+class _RayCounterWrapper:
+    """Wraps around ray.util.metrics.Counter to provide same API as
+    prometheus_client.Counter"""
+
+    def __init__(self,
+                 name: str,
+                 documentation: str = "",
+                 labelnames: Optional[List[str]] = None):
+        labelnames_tuple = tuple(labelnames) if labelnames else None
+        self._counter = ray_metrics.Counter(name=name,
+                                            description=documentation,
+                                            tag_keys=labelnames_tuple)
+
+    def labels(self, **labels):
+        self._counter.set_default_tags(labels)
+        return self
+
+    def inc(self, value: Union[int, float] = 1.0):
+        if value == 0:
+            return
+        return self._counter.inc(value)
+
+
+class _RayHistogramWrapper:
+    """Wraps around ray.util.metrics.Histogram to provide same API as
+    prometheus_client.Histogram"""
+
+    def __init__(self,
+                 name: str,
+                 documentation: str = "",
+                 labelnames: Optional[List[str]] = None,
+                 buckets: Optional[List[float]] = None):
+        labelnames_tuple = tuple(labelnames) if labelnames else None
+        self._histogram = ray_metrics.Histogram(name=name,
+                                                description=documentation,
+                                                tag_keys=labelnames_tuple,
+                                                boundaries=buckets)
+
+    def labels(self, **labels):
+        self._histogram.set_default_tags(labels)
+        return self
+
+    def observe(self, value: Union[int, float]):
+        return self._histogram.observe(value)


 class RayMetrics(Metrics):
@@ -181,7 +255,9 @@ class RayMetrics(Metrics):
     RayMetrics is used by RayPrometheusStatLogger to log to Ray metrics.
     Provides the same metrics as Metrics but uses Ray's util.metrics library.
     """
-    _base_library = ray_metrics
+    _gauge_cls = _RayGaugeWrapper
+    _counter_cls = _RayCounterWrapper
+    _histogram_cls = _RayHistogramWrapper

     def __init__(self, labelnames: List[str], max_model_len: int):
         if ray_metrics is None:
@@ -192,8 +268,9 @@ def _unregister_vllm_metrics(self) -> None:
         # No-op on purpose
         pass

-
-# end-metrics-definitions
+    def _create_info_cache_config(self) -> None:
+        # No-op on purpose
+        pass


 def build_1_2_5_buckets(max_value: int) -> List[int]:
@@ -498,3 +575,6 @@ def log(self, stats: Stats):
 class RayPrometheusStatLogger(PrometheusStatLogger):
     """RayPrometheusStatLogger uses Ray metrics instead."""
     _metrics_cls = RayMetrics
+
+    def info(self, type: str, obj: SupportsMetricsInfo) -> None:
+        return None
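
A minimal usage sketch (not part of the patch itself): the new test above already exercises the whole flow, and the snippet below condenses it. It assumes a vLLM build with this patch applied and a Ray cluster with at least one GPU; the model name, prompt, and request id are placeholders rather than values taken from the patch.

import ray

from vllm import EngineArgs, LLMEngine
from vllm.engine.metrics import RayPrometheusStatLogger
from vllm.sampling_params import SamplingParams


# Ray metrics are only emitted correctly from inside a Ray task, so the
# engine and the logger both live in a remote function (mirroring
# test_engine_log_metrics_ray above).
@ray.remote(num_gpus=1)
def _generate_with_ray_metrics():
    # Placeholder model; any model vLLM can load works here.
    engine = LLMEngine.from_engine_args(
        EngineArgs(model="facebook/opt-125m", disable_log_stats=False))
    logger = RayPrometheusStatLogger(
        local_interval=0.5,
        labels=dict(model_name=engine.model_config.served_model_name),
        max_model_len=engine.model_config.max_model_len)
    engine.add_logger("ray", logger)

    # Placeholder prompt and request id.
    engine.add_request("request-0", "Hello, my name is",
                       SamplingParams(max_tokens=16))
    while engine.has_unfinished_requests():
        engine.step()


ray.get(_generate_with_ray_metrics.remote())

When a logger can be constructed before the engine, the stat_loggers dict that from_engine_args now accepts is an alternative to calling add_logger after construction.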