Skip to content

Commit

Permalink
[Bugfix] [SpecDecode] AsyncMetricsCollector: update time since last c…
Browse files Browse the repository at this point in the history
…ollection (#6578)
  • Loading branch information
tdoublep authored Jul 19, 2024
1 parent 30efe41 commit f0bbfaf
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 0 deletions.
43 changes: 43 additions & 0 deletions tests/spec_decode/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,49 @@ def test_noop_until_time():
assert metrics is not None


def test_timer_is_reset():
"""Verify that the internal timer inside AsyncMetricsCollector
is reset after collection.
"""
spec_decode_sampler = MagicMock()
spec_decode_sampler.num_accepted_tokens = torch.tensor(0,
dtype=torch.long,
device='cuda')
spec_decode_sampler.num_emitted_tokens = torch.tensor(0,
dtype=torch.long,
device='cuda')
spec_decode_sampler.num_draft_tokens = 0

collect_interval_s = 5.0
timer = MagicMock()
timer.side_effect = [
0.0,
collect_interval_s + 0.1,
collect_interval_s + 0.1,
collect_interval_s + 0.2,
collect_interval_s + 0.2,
2 * collect_interval_s + 0.1,
2 * collect_interval_s + 0.1,
]

collector = AsyncMetricsCollector(spec_decode_sampler=spec_decode_sampler,
timer=timer,
collect_interval_s=collect_interval_s)
collector.init_gpu_tensors(rank=0)

_ = collector.maybe_collect_rejsample_metrics(k=5)
metrics = collector.maybe_collect_rejsample_metrics(k=5)
assert metrics is not None

_ = collector.maybe_collect_rejsample_metrics(k=5)
metrics = collector.maybe_collect_rejsample_metrics(k=5)
assert metrics is None

_ = collector.maybe_collect_rejsample_metrics(k=5)
metrics = collector.maybe_collect_rejsample_metrics(k=5)
assert metrics is not None


@pytest.mark.parametrize("has_data", [True, False])
def test_initial_metrics_has_correct_values(has_data: bool):
"""Test correctness of metrics data.
Expand Down
4 changes: 4 additions & 0 deletions vllm/spec_decode/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,10 @@ def _collect_rejsample_metrics(
"""

ready_event.synchronize()

# update time of last collection
self._last_metrics_collect_time = self._timer()

accepted_tokens = self._aggregate_num_accepted_tokens.item()
emitted_tokens = self._aggregate_num_emitted_tokens.item()
draft_tokens = self._aggregate_num_draft_tokens
Expand Down

0 comments on commit f0bbfaf

Please sign in to comment.