diff --git a/tests/metrics/test_metrics.py b/tests/metrics/test_metrics.py index 92e6086e312f7..7a361ef320810 100644 --- a/tests/metrics/test_metrics.py +++ b/tests/metrics/test_metrics.py @@ -84,6 +84,45 @@ def test_metric_counter_generation_tokens( f"metric: {metric_count!r}") +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("max_tokens", [128, 129]) +@pytest.mark.parametrize("disable_async_output_proc", [True, False]) +def test_metric_counter_generation_tokens_multi_step( + vllm_runner, + example_prompts, + model: str, + max_tokens: int, + disable_async_output_proc: bool, +) -> None: + num_scheduler_steps = 8 + with vllm_runner( + model, + disable_log_stats=False, + gpu_memory_utilization=0.4, + num_scheduler_steps=num_scheduler_steps, + disable_async_output_proc=disable_async_output_proc, + ) as vllm_model: + vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) + tokenizer = vllm_model.model.get_tokenizer() + stat_logger = vllm_model.model.llm_engine.stat_loggers['prometheus'] + metric_count = stat_logger.metrics.counter_generation_tokens.labels( + **stat_logger.labels)._value.get() + vllm_generation_count = 0 + for i in range(len(example_prompts)): + vllm_output_ids, vllm_output_str = vllm_outputs[i] + prompt_ids = tokenizer.encode(example_prompts[i]) + # vllm_output_ids contains both prompt tokens and generation tokens. + # We're interested only in the count of the generation tokens. + vllm_generation_count += len(vllm_output_ids) - len(prompt_ids) + + # The multi-step scheduling will continue to execute forward even when + # encountering EOS, leading to slightly imprecise metrics. + assert abs(vllm_generation_count - metric_count) <\ + len(example_prompts) * num_scheduler_steps, \ + (f"generation token count: {vllm_generation_count!r}\n" + f"metric: {metric_count!r}") + + @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["float"]) @pytest.mark.parametrize( diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 3a29e6a9ae094..99beea932882d 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1718,6 +1718,15 @@ def _get_stats(self, # TPOTs. latency = seq_group.get_last_latency(now) time_per_output_tokens_iter.append(latency) + if seq_group.state.current_step == 0: + # For async_output_proc, the do_log_stats() + # is called following init_multi_step(), which + # sets the current_step to zero. + actual_num_batched_tokens +=\ + seq_group.state.num_steps - 1 + else: + actual_num_batched_tokens +=\ + seq_group.state.current_step - 1 # Because of chunked prefill, we can have a single sequence # group that does multiple prompt_runs. To prevent logging