
Commit

fix
comaniac committed Jul 15, 2024
1 parent b25c837 commit 8b67a8b
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions tests/metrics/test_metrics.py
@@ -169,7 +169,7 @@ def test_engine_log_metrics_regression(


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [10])
def test_metric_spec_decode(
vllm_runner,
@@ -188,6 +188,10 @@ def test_metric_spec_decode(
num_speculative_tokens=k,
use_v2_block_manager=True) as vllm_model:

+# Force log interval to be 0 to catch all metrics.
+stat_logger = vllm_model.model.llm_engine.stat_loggers['prometheus']
+stat_logger.local_interval = 0
+
# Note that the purpose of this test is to verify spec decode
# metrics instead of functional correctness, so the expected values
# are intended to be loose.
@@ -204,7 +208,6 @@
prompts = example_prompts[:1]

_ = vllm_model.generate_greedy(prompts, max_tokens)
-stat_logger = vllm_model.model.llm_engine.stat_loggers['prometheus']
for metric_name, is_expected in metric_name_to_expected_fn.items():
metric_val = getattr(
stat_logger.metrics,
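For context on the stat_logger.local_interval = 0 change: vLLM's Prometheus stat logger only flushes its locally aggregated stats once the configured interval has elapsed, so with a non-zero interval a short test generation can finish before the spec-decode metrics are ever recorded. The toy class below is a minimal sketch of that interval-gating pattern, purely illustrative (IntervalGatedLogger and its fields are made up for this example, not vLLM's implementation):

import time

class IntervalGatedLogger:
    """Toy stand-in for an interval-gated stat logger (illustrative only)."""

    def __init__(self, local_interval: float):
        self.local_interval = local_interval
        self._last_flush = 0.0
        self.flushed = []  # stats that were actually emitted

    def log(self, stat) -> None:
        now = time.monotonic()
        # Emit only when the configured interval has elapsed since the last flush.
        if now - self._last_flush >= self.local_interval:
            self.flushed.append(stat)
            self._last_flush = now

logger = IntervalGatedLogger(local_interval=5.0)
logger.log("first")        # emitted: nothing has been flushed yet
logger.log("second")       # skipped: 5 s have not elapsed since the last flush
logger.local_interval = 0  # what the test does: every call now flushes
logger.log("third")        # emitted immediately
print(logger.flushed)      # ['first', 'third']

With local_interval set to 0 the gate is always open, which matches the intent of the "catch all metrics" comment added in the diff.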
