commit 38763d2 (1 parent: 96fb020)
youkaichao committed Dec 10, 2024
Signed-off-by: youkaichao <[email protected]>

move to forward context

Showing 2 changed files with 23 additions and 31 deletions.
vllm/compilation/decorators.py (0 additions, 31 deletions)

@@ -1,12 +1,9 @@
 import inspect
-import time
-from collections import Counter
 from typing import Callable, Dict, List, Optional, TypeVar, Union, overload
 
 import torch
 import torch.nn as nn
 
-import vllm.envs as envs
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
 from vllm.config import CompilationLevel, VllmConfig
@@ -149,10 +146,6 @@ def _support_torch_compile(
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs):
         old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs)
         self.vllm_config = vllm_config
-        self.track_batchsize = envs.VLLM_LOG_BATCHSIZE_INTERVAL >= 0
-        self.batchsize_counter = Counter()
-        self.last_logging_time = 0
-        self.batchsize_logging_interval = envs.VLLM_LOG_BATCHSIZE_INTERVAL
         # for CompilationLevel.DYNAMO_AS_IS , the upper level model runner
         # will handle the compilation, so we don't need to do anything here.
         self.do_not_compile = \
@@ -171,30 +164,6 @@ def __call__(self, *args, **kwargs):
         # torch.compiler.is_compiling() means we are inside the compilation
         # e.g. TPU has the compilation logic in model runner, so we don't
         # need to compile the model inside.
-        if self.track_batchsize:
-            tensors = []
-            for arg in args + tuple(kwargs.values()):
-                if isinstance(arg, torch.Tensor):
-                    tensors.append(arg)
-                elif isinstance(arg, IntermediateTensors):
-                    for tensor in arg.tensors.values():
-                        tensors.append(tensor)
-            # ignore kv cache tensors and empty tensors
-            bs = [
-                tensor.shape[0] for tensor in tensors
-                if len(tensor.shape) <= 2 and tensor.shape[0] >= 1
-            ]
-            for b in bs:
-                assert b == bs[0]
-            self.batchsize_counter[bs[0]] += 1
-            if time.monotonic(
-            ) - self.last_logging_time > self.batchsize_logging_interval:
-                self.last_logging_time = time.monotonic()
-                sorted_data = sorted(list(self.batchsize_counter.items()),
-                                     key=lambda x: x[1],
-                                     reverse=True)
-                logger.info("Batchsize distribution (batchsize, count): %s",
-                            sorted_data)
         if self.do_not_compile or torch.compiler.is_compiling():
             return self.forward(*args, **kwargs)
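
The deleted block above had to reverse-engineer the batch size from whatever tensors happened to be passed to forward(), while the forward context receives attention metadata that already states the token counts. A minimal runnable sketch of the contrast; the sample tensors and FakeMetadata are illustrative stand-ins, not vLLM API:

import torch

# Old approach (removed above): scan forward() arguments for tensors and
# infer the batch size from the first dimension, filtering out anything
# that does not look like a per-token input (kv caches, empty tensors).
tensors = [torch.zeros(4, 8), torch.zeros(4)]
bs = [t.shape[0] for t in tensors if len(t.shape) <= 2 and t.shape[0] >= 1]
assert all(b == bs[0] for b in bs)  # the old code asserted agreement

# New approach (added below): the attention metadata carries the counts
# directly, so no shape inference or consistency check is needed.
class FakeMetadata:
    num_prefill_tokens = 3
    num_decode_tokens = 1

assert bs[0] == FakeMetadata.num_prefill_tokens + FakeMetadata.num_decode_tokens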
vllm/forward_context.py (23 additions, 0 deletions)
@@ -1,8 +1,19 @@
+import time
+from collections import Counter
 from contextlib import contextmanager
 from dataclasses import dataclass
 from typing import Any, Dict, Optional
 
+import vllm.envs as envs
 from vllm.config import VllmConfig
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+track_batchsize = envs.VLLM_LOG_BATCHSIZE_INTERVAL >= 0
+batchsize_counter = Counter()
+last_logging_time: float = 0
+batchsize_logging_interval = envs.VLLM_LOG_BATCHSIZE_INTERVAL
 
 
 @dataclass
@@ -27,6 +38,18 @@ def get_forward_context() -> ForwardContext:
 def set_forward_context(context: Any, vllm_config: VllmConfig):
     """A context manager that stores the current forward context,
     can be attention metadata, etc."""
+    global track_batchsize, batchsize_counter
+    global last_logging_time, batchsize_logging_interval
+    if track_batchsize and context is not None:
+        batchsize = context.num_prefill_tokens + context.num_decode_tokens
+        batchsize_counter[batchsize] += 1
+        if time.monotonic() - last_logging_time > batchsize_logging_interval:
+            last_logging_time = time.monotonic()
+            sorted_data = sorted(list(batchsize_counter.items()),
+                                 key=lambda x: x[1],
+                                 reverse=True)
+            logger.info("Batchsize distribution (batchsize, count): %s",
+                        sorted_data)
     global _forward_context
     prev_context = _forward_context
     _forward_context = ForwardContext(
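
For orientation, a hedged sketch of how a model runner might drive the new code path; only set_forward_context and VLLM_LOG_BATCHSIZE_INTERVAL come from this commit, while run_forward, model, attn_metadata, vllm_config, and the inputs are hypothetical placeholders:

import os

# Enable tracking before vllm.forward_context is imported, since
# track_batchsize is computed at module import time. The interval is
# compared against time.monotonic() deltas, so the value is in seconds,
# and any value >= 0 turns tracking on.
os.environ["VLLM_LOG_BATCHSIZE_INTERVAL"] = "30"

from vllm.forward_context import set_forward_context

def run_forward(model, attn_metadata, vllm_config, *inputs):
    # Hypothetical driver; real model runners already have these in scope.
    with set_forward_context(attn_metadata, vllm_config):
        # On entry, num_prefill_tokens + num_decode_tokens is counted and,
        # at most once per interval, the distribution is logged.
        return model(*inputs)

One apparent effect of the move: the distribution now covers every forward pass that enters the context manager, rather than only models wrapped by the compilation decorator.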
