diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py
index 45970a572c5ac..938430fe2a501 100644
--- a/vllm/compilation/decorators.py
+++ b/vllm/compilation/decorators.py
@@ -1,12 +1,9 @@
 import inspect
-import time
-from collections import Counter
 from typing import Callable, Dict, List, Optional, TypeVar, Union, overload
 
 import torch
 import torch.nn as nn
 
-import vllm.envs as envs
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
 from vllm.config import CompilationLevel, VllmConfig
@@ -149,10 +146,6 @@ def _support_torch_compile(
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs):
         old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs)
         self.vllm_config = vllm_config
-        self.track_batchsize = envs.VLLM_LOG_BATCHSIZE_INTERVAL >= 0
-        self.batchsize_counter = Counter()
-        self.last_logging_time = 0
-        self.batchsize_logging_interval = envs.VLLM_LOG_BATCHSIZE_INTERVAL
         # for CompilationLevel.DYNAMO_AS_IS , the upper level model runner
         # will handle the compilation, so we don't need to do anything here.
         self.do_not_compile = \
@@ -171,30 +164,6 @@ def __call__(self, *args, **kwargs):
         # torch.compiler.is_compiling() means we are inside the compilation
         # e.g. TPU has the compilation logic in model runner, so we don't
         # need to compile the model inside.
-        if self.track_batchsize:
-            tensors = []
-            for arg in args + tuple(kwargs.values()):
-                if isinstance(arg, torch.Tensor):
-                    tensors.append(arg)
-                elif isinstance(arg, IntermediateTensors):
-                    for tensor in arg.tensors.values():
-                        tensors.append(tensor)
-            # ignore kv cache tensors and empty tensors
-            bs = [
-                tensor.shape[0] for tensor in tensors
-                if len(tensor.shape) <= 2 and tensor.shape[0] >= 1
-            ]
-            for b in bs:
-                assert b == bs[0]
-            self.batchsize_counter[bs[0]] += 1
-            if time.monotonic(
-            ) - self.last_logging_time > self.batchsize_logging_interval:
-                self.last_logging_time = time.monotonic()
-                sorted_data = sorted(list(self.batchsize_counter.items()),
-                                     key=lambda x: x[1],
-                                     reverse=True)
-                logger.info("Batchsize distribution (batchsize, count): %s",
-                            sorted_data)
         if self.do_not_compile or torch.compiler.is_compiling():
             return self.forward(*args, **kwargs)
 
diff --git a/vllm/forward_context.py b/vllm/forward_context.py
index aaa3e4bb3a1e8..ad631334dadb3 100644
--- a/vllm/forward_context.py
+++ b/vllm/forward_context.py
@@ -1,8 +1,19 @@
+import time
+from collections import Counter
 from contextlib import contextmanager
 from dataclasses import dataclass
 from typing import Any, Dict, Optional
 
+import vllm.envs as envs
 from vllm.config import VllmConfig
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+track_batchsize = envs.VLLM_LOG_BATCHSIZE_INTERVAL >= 0
+batchsize_counter = Counter()
+last_logging_time: float = 0
+batchsize_logging_interval = envs.VLLM_LOG_BATCHSIZE_INTERVAL
 
 
 @dataclass
@@ -27,6 +38,18 @@ def get_forward_context() -> ForwardContext:
 def set_forward_context(context: Any, vllm_config: VllmConfig):
     """A context manager that stores the current forward context,
     can be attention metadata, etc."""
+    global track_batchsize, batchsize_counter
+    global last_logging_time, batchsize_logging_interval
+    if track_batchsize and context is not None:
+        batchsize = context.num_prefill_tokens + context.num_decode_tokens
+        batchsize_counter[batchsize] += 1
+        if time.monotonic() - last_logging_time > batchsize_logging_interval:
+            last_logging_time = time.monotonic()
+            sorted_data = sorted(list(batchsize_counter.items()),
+                                 key=lambda x: x[1],
+                                 reverse=True)
+            logger.info("Batchsize distribution (batchsize, count): %s",
+                        sorted_data)
     global _forward_context
     prev_context = _forward_context
     _forward_context = ForwardContext(
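
The diff moves the batch-size distribution logging out of the torch.compile decorator (where it had to reverse-engineer the batch size from tensor shapes) and into set_forward_context, where the attention metadata already carries the token counts. For reference, below is a minimal standalone sketch of the logging pattern the new vllm/forward_context.py code uses: module-level state (a Counter plus a last-logging timestamp) that is flushed to the log at a fixed interval. The record_batchsize helper, the 5-second interval, and the simulated loop are illustrative placeholders, not part of vLLM's API; in vLLM the interval comes from the VLLM_LOG_BATCHSIZE_INTERVAL environment variable.

```python
import logging
import time
from collections import Counter

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Module-level state, mirroring the globals added to vllm/forward_context.py.
# The 5-second interval is a stand-in for envs.VLLM_LOG_BATCHSIZE_INTERVAL.
batchsize_counter: Counter = Counter()
last_logging_time: float = 0.0
batchsize_logging_interval: float = 5.0


def record_batchsize(batchsize: int) -> None:
    """Count one forward pass of `batchsize` tokens and, at most once per
    interval, log the observed distribution sorted by frequency."""
    global last_logging_time
    batchsize_counter[batchsize] += 1
    if time.monotonic() - last_logging_time > batchsize_logging_interval:
        last_logging_time = time.monotonic()
        sorted_data = sorted(batchsize_counter.items(),
                             key=lambda x: x[1],
                             reverse=True)
        logger.info("Batchsize distribution (batchsize, count): %s",
                    sorted_data)


if __name__ == "__main__":
    # Simulate a few forward passes with varying batch sizes.
    for bs in (8, 8, 16, 32, 8, 16):
        record_batchsize(bs)
        time.sleep(2)
```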