From 56804528c8cd87238ce002badbe7d9d8ba696086 Mon Sep 17 00:00:00 2001
From: Cody Yu
Date: Tue, 16 Jul 2024 09:40:24 -0700
Subject: [PATCH] fix test

---
 tests/spec_decode/e2e/conftest.py | 14 +++++++++-----
 1 file changed, 9 insertions(+), 5 deletions(-)

diff --git a/tests/spec_decode/e2e/conftest.py b/tests/spec_decode/e2e/conftest.py
index ffb022044e130..dafb2422aa70f 100644
--- a/tests/spec_decode/e2e/conftest.py
+++ b/tests/spec_decode/e2e/conftest.py
@@ -185,7 +185,8 @@ def generator_inner():
 
         # Override logging interval to 0 for spec decode test run to
         # log all metrics in time.
-        if baseline_or_test == "test" and not use_async:
+        if (baseline_or_test == "test" and not use_async
+                and llm.llm_engine.log_stats):
             for sate_logger in llm.llm_engine.stat_loggers.values():
                 sate_logger.local_interval = 0
         set_random_seed(seed)
@@ -228,10 +229,13 @@ def get_output_from_llm_generator(
         outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
         token_ids = [output.outputs[0].token_ids for output in outputs]
         tokens = [output.outputs[0].text for output in outputs]
-        stat_logger = llm.llm_engine.stat_loggers["prometheus"]
-        acceptance_rate = (
-            stat_logger.metrics.gauge_spec_decode_draft_acceptance_rate.labels(
-                **stat_logger.labels)._value.get())
+
+        # Fetch acceptance rate if logging is enabled.
+        if llm.llm_engine.log_stats:
+            stat_logger = llm.llm_engine.stat_loggers["prometheus"]
+            acceptance_rate = (stat_logger.metrics.
+                               gauge_spec_decode_draft_acceptance_rate.labels(
+                                   **stat_logger.labels)._value.get())
         del llm
 
     return tokens, token_ids, acceptance_rate
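
Below is a minimal, self-contained Python sketch (not vLLM code) of the pattern this patch introduces: only query the Prometheus stat logger when the engine was built with stats logging enabled, and fall back to a sentinel acceptance rate otherwise. The FakeGauge/FakeStatLogger/FakeEngine classes and the fetch_acceptance_rate helper are hypothetical stand-ins; the real test reads llm.llm_engine.stat_loggers["prometheus"] as shown in the diff. If acceptance_rate is not already initialized elsewhere in the real function (outside this hunk), a default such as the -1.0 below would be needed so the final return does not reference an unbound name.

# Hypothetical stand-ins (FakeGauge, FakeStatLogger, FakeEngine) used only to
# illustrate the guarded metric fetch; they are not part of vLLM.
from dataclasses import dataclass, field
from typing import Dict


@dataclass
class FakeGauge:
    value: float = 0.0

    def get(self) -> float:
        return self.value


@dataclass
class FakeStatLogger:
    gauge: FakeGauge = field(default_factory=FakeGauge)


@dataclass
class FakeEngine:
    log_stats: bool = True
    stat_loggers: Dict[str, FakeStatLogger] = field(default_factory=dict)


def fetch_acceptance_rate(engine: FakeEngine) -> float:
    # Mirror the patched conftest logic: skip the Prometheus lookup entirely
    # when stats logging is disabled, returning a sentinel instead.
    acceptance_rate = -1.0
    if engine.log_stats:
        stat_logger = engine.stat_loggers["prometheus"]
        acceptance_rate = stat_logger.gauge.get()
    return acceptance_rate


# With logging enabled the gauge value comes back; with it disabled the
# sentinel is returned rather than raising a KeyError on the missing logger.
enabled = FakeEngine(
    stat_loggers={"prometheus": FakeStatLogger(FakeGauge(0.63))})
disabled = FakeEngine(log_stats=False)
assert fetch_acceptance_rate(enabled) == 0.63
assert fetch_acceptance_rate(disabled) == -1.0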