From 8152b265b70873812239b6d072be341eecaad944 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Tue, 16 Jul 2024 13:37:10 -0700 Subject: [PATCH] [Misc] Log spec decode metrics (#6454) --- tests/metrics/test_metrics.py | 49 +++++++++++++++++++ tests/spec_decode/e2e/conftest.py | 44 ++++++++++++++--- .../e2e/test_multistep_correctness.py | 18 ++++--- vllm/engine/metrics.py | 40 +++++++++++++++ 4 files changed, 137 insertions(+), 14 deletions(-) diff --git a/tests/metrics/test_metrics.py b/tests/metrics/test_metrics.py index 0191d85194e33..5694061e17e02 100644 --- a/tests/metrics/test_metrics.py +++ b/tests/metrics/test_metrics.py @@ -168,6 +168,55 @@ def test_engine_log_metrics_regression( assert_metrics(engine, disable_log_stats, len(example_prompts)) +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [10]) +def test_metric_spec_decode( + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, +) -> None: + k = 5 + + with vllm_runner(model, + dtype=dtype, + disable_log_stats=False, + gpu_memory_utilization=0.4, + speculative_model=model, + num_speculative_tokens=k, + use_v2_block_manager=True) as vllm_model: + + # Force log interval to be 0 to catch all metrics. + stat_logger = vllm_model.model.llm_engine.stat_loggers['prometheus'] + stat_logger.local_interval = 0 + + # Note that the purpose of this test is to verify spec decode + # metrics instead of functional correctness, so the expected values + # are intended to be loose. + metric_name_to_expected_fn = { + "gauge_spec_decode_draft_acceptance_rate": lambda v: 0 <= v <= 1, + "gauge_spec_decode_efficiency": lambda v: 0 <= v <= 1, + "counter_spec_decode_num_accepted_tokens": lambda v: 0 <= v <= k, + "counter_spec_decode_num_draft_tokens": lambda v: v == k, + "counter_spec_decode_num_emitted_tokens": + lambda v: 0 <= v <= k + 1, + } + + # Use one request to better inspect the metrics. + prompts = example_prompts[:1] + + _ = vllm_model.generate_greedy(prompts, max_tokens) + for metric_name, is_expected in metric_name_to_expected_fn.items(): + metric_val = getattr( + stat_logger.metrics, + metric_name).labels(**stat_logger.labels)._value.get() + assert is_expected(metric_val), ( + f"the value of metric {metric_name} ({metric_val}) " + "does not meet expectation") + + def assert_metrics(engine: LLMEngine, disable_log_stats: bool, num_requests: int) -> None: if disable_log_stats: diff --git a/tests/spec_decode/e2e/conftest.py b/tests/spec_decode/e2e/conftest.py index fb3415b5db153..34a6c9a393a58 100644 --- a/tests/spec_decode/e2e/conftest.py +++ b/tests/spec_decode/e2e/conftest.py @@ -162,6 +162,11 @@ def create_llm_generator(baseline_or_test, request, common_llm_kwargs, } test_name = request.node.name + model = kwargs["model"] + draft_model = kwargs.get("speculative_model", None) + same_draft_target_model = (draft_model is not None + and draft_model == model) + def generator_inner(): wait_for_gpu_memory_to_clear( @@ -177,6 +182,13 @@ def generator_inner(): print(f'Creating {baseline_or_test=} LLM for {test_name=}. {kwargs=}') llm = AsyncLLM(**kwargs) if use_async else LLM(**kwargs) + + # Override logging interval to 0 for spec decode test run to + # log all metrics in time. + if (baseline_or_test == "test" and not use_async + and llm.llm_engine.log_stats): + for sate_logger in llm.llm_engine.stat_loggers.values(): + sate_logger.local_interval = 0 set_random_seed(seed) yield llm @@ -188,6 +200,9 @@ def generator_outer(): yield llm del llm + # Set an attribute to the generator_outer function to allow us to + # determine whether to further check the acceptance rate in tests. + generator_outer.same_draft_target_model = same_draft_target_model # type: ignore return generator_outer @@ -204,18 +219,26 @@ def maybe_assert_ngram_worker(llm): def get_output_from_llm_generator( llm_generator, prompts, - sampling_params) -> Tuple[List[str], List[List[int]]]: + sampling_params) -> Tuple[List[str], List[List[int]], float]: tokens: List[str] = [] token_ids: List[List[int]] = [] + acceptance_rate: float = -1.0 for llm in llm_generator(): maybe_assert_ngram_worker(llm) outputs = llm.generate(prompts, sampling_params, use_tqdm=True) token_ids = [output.outputs[0].token_ids for output in outputs] tokens = [output.outputs[0].text for output in outputs] + + # Fetch acceptance rate if logging is enabled. + if stat_loggers := getattr(llm.llm_engine, "stat_loggers", None): + stat_logger = stat_loggers["prometheus"] + acceptance_rate = (stat_logger.metrics. + gauge_spec_decode_draft_acceptance_rate.labels( + **stat_logger.labels)._value.get()) del llm - return tokens, token_ids + return tokens, token_ids, acceptance_rate def get_logprobs_from_llm_generator( @@ -237,7 +260,8 @@ def run_greedy_equality_correctness_test(baseline_llm_generator, batch_size, max_output_len, force_output_len: bool, - print_tokens: bool = False): + print_tokens: bool = False, + ensure_all_accepted: bool = False): """Helper method that compares the outputs of both the baseline LLM and the test LLM. It asserts greedy equality, e.g. that the outputs are exactly the same when temperature is zero. @@ -267,12 +291,13 @@ def run_greedy_equality_correctness_test(baseline_llm_generator, temperature=temperature, ) - spec_batch_tokens, spec_batch_token_ids = get_output_from_llm_generator( - test_llm_generator, prompts, sampling_params) + (spec_batch_tokens, spec_batch_token_ids, + acceptance_rate) = get_output_from_llm_generator(test_llm_generator, + prompts, sampling_params) - (baseline_batch_tokens, - baseline_batch_token_ids) = get_output_from_llm_generator( - baseline_llm_generator, prompts, sampling_params) + (baseline_batch_tokens, baseline_batch_token_ids, + _) = get_output_from_llm_generator(baseline_llm_generator, prompts, + sampling_params) assert len(baseline_batch_token_ids) == len(prompts) assert len(spec_batch_token_ids) == len(prompts) @@ -287,3 +312,6 @@ def run_greedy_equality_correctness_test(baseline_llm_generator, print(f'{i=} {baseline_token_ids=}') print(f'{i=} {spec_token_ids=}') assert baseline_token_ids == spec_token_ids + + if ensure_all_accepted: + assert acceptance_rate == 1.0 diff --git a/tests/spec_decode/e2e/test_multistep_correctness.py b/tests/spec_decode/e2e/test_multistep_correctness.py index 94cc36f22875a..86cab7aba2380 100644 --- a/tests/spec_decode/e2e/test_multistep_correctness.py +++ b/tests/spec_decode/e2e/test_multistep_correctness.py @@ -97,7 +97,7 @@ def test_spec_decode_e2e_with_detokenization(test_llm_generator, temperature=temperature, ) - batch_tokens, batch_token_ids = get_output_from_llm_generator( + batch_tokens, batch_token_ids, _ = get_output_from_llm_generator( test_llm_generator, prompts, sampling_params) # Expect a generation for each prompt in the batch. @@ -200,12 +200,18 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1( Since this test is cheaper than other e2e correctness tests, we generate with a higher output_len. + + When the draft model is the same as the target model, we further check + whether all speculative tokens are accepted. """ - run_greedy_equality_correctness_test(baseline_llm_generator, - test_llm_generator, - batch_size, - max_output_len=output_len, - force_output_len=True) + ensure_all_accepted = test_llm_generator.same_draft_target_model + run_greedy_equality_correctness_test( + baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True, + ensure_all_accepted=ensure_all_accepted) @pytest.mark.parametrize( diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py index 48aec84298d86..9e187b2fbc33b 100644 --- a/vllm/engine/metrics.py +++ b/vllm/engine/metrics.py @@ -133,6 +133,30 @@ def __init__(self, labelnames: List[str], max_model_len: int): documentation="Count of successfully processed requests.", labelnames=labelnames + [Metrics.labelname_finish_reason]) + # Speculatie decoding stats + self.gauge_spec_decode_draft_acceptance_rate = self._base_library.Gauge( + name="vllm:spec_decode_draft_acceptance_rate", + documentation="Speulative token acceptance rate.", + labelnames=labelnames) + self.gauge_spec_decode_efficiency = self._base_library.Gauge( + name="vllm:spec_decode_efficiency", + documentation="Speculative decoding system efficiency.", + labelnames=labelnames) + self.counter_spec_decode_num_accepted_tokens = ( + self._base_library.Counter( + name="vllm:spec_decode_num_accepted_tokens_total", + documentation="Number of accepted tokens.", + labelnames=labelnames)) + self.counter_spec_decode_num_draft_tokens = self._base_library.Counter( + name="vllm:spec_decode_num_draft_tokens_total", + documentation="Number of draft tokens.", + labelnames=labelnames) + self.counter_spec_decode_num_emitted_tokens = ( + self._base_library.Counter( + name="vllm:spec_decode_num_emitted_tokens_total", + documentation="Number of emitted tokens.", + labelnames=labelnames)) + # Deprecated in favor of vllm:prompt_tokens_total self.gauge_avg_prompt_throughput = self._base_library.Gauge( name="vllm:avg_prompt_throughput_toks_per_s", @@ -454,6 +478,22 @@ def log(self, stats: Stats): self.num_generation_tokens = [] self.last_local_log = stats.now + if stats.spec_decode_metrics is not None: + self._log_gauge( + self.metrics.gauge_spec_decode_draft_acceptance_rate, + stats.spec_decode_metrics.draft_acceptance_rate) + self._log_gauge(self.metrics.gauge_spec_decode_efficiency, + stats.spec_decode_metrics.system_efficiency) + self._log_counter( + self.metrics.counter_spec_decode_num_accepted_tokens, + stats.spec_decode_metrics.accepted_tokens) + self._log_counter( + self.metrics.counter_spec_decode_num_draft_tokens, + stats.spec_decode_metrics.draft_tokens) + self._log_counter( + self.metrics.counter_spec_decode_num_emitted_tokens, + stats.spec_decode_metrics.emitted_tokens) + class RayPrometheusStatLogger(PrometheusStatLogger): """RayPrometheusStatLogger uses Ray metrics instead."""