diff --git a/tests/spec_decode/test_metrics.py b/tests/spec_decode/test_metrics.py index 2918fabddc900..7477486a3388d 100644 --- a/tests/spec_decode/test_metrics.py +++ b/tests/spec_decode/test_metrics.py @@ -105,6 +105,49 @@ def test_noop_until_time(): assert metrics is not None +def test_timer_is_reset(): + """Verify that the internal timer inside AsyncMetricsCollector + is reset after collection. + """ + spec_decode_sampler = MagicMock() + spec_decode_sampler.num_accepted_tokens = torch.tensor(0, + dtype=torch.long, + device='cuda') + spec_decode_sampler.num_emitted_tokens = torch.tensor(0, + dtype=torch.long, + device='cuda') + spec_decode_sampler.num_draft_tokens = 0 + + collect_interval_s = 5.0 + timer = MagicMock() + timer.side_effect = [ + 0.0, + collect_interval_s + 0.1, + collect_interval_s + 0.1, + collect_interval_s + 0.2, + collect_interval_s + 0.2, + 2 * collect_interval_s + 0.1, + 2 * collect_interval_s + 0.1, + ] + + collector = AsyncMetricsCollector(spec_decode_sampler=spec_decode_sampler, + timer=timer, + collect_interval_s=collect_interval_s) + collector.init_gpu_tensors(rank=0) + + _ = collector.maybe_collect_rejsample_metrics(k=5) + metrics = collector.maybe_collect_rejsample_metrics(k=5) + assert metrics is not None + + _ = collector.maybe_collect_rejsample_metrics(k=5) + metrics = collector.maybe_collect_rejsample_metrics(k=5) + assert metrics is None + + _ = collector.maybe_collect_rejsample_metrics(k=5) + metrics = collector.maybe_collect_rejsample_metrics(k=5) + assert metrics is not None + + @pytest.mark.parametrize("has_data", [True, False]) def test_initial_metrics_has_correct_values(has_data: bool): """Test correctness of metrics data. diff --git a/vllm/spec_decode/metrics.py b/vllm/spec_decode/metrics.py index 2c4ae0b22744b..9036d117041f0 100644 --- a/vllm/spec_decode/metrics.py +++ b/vllm/spec_decode/metrics.py @@ -145,6 +145,10 @@ def _collect_rejsample_metrics( """ ready_event.synchronize() + + # update time of last collection + self._last_metrics_collect_time = self._timer() + accepted_tokens = self._aggregate_num_accepted_tokens.item() emitted_tokens = self._aggregate_num_emitted_tokens.item() draft_tokens = self._aggregate_num_draft_tokens