diff --git a/vllm/forward_context.py b/vllm/forward_context.py
index ad631334dadb3..2fd82ca6b606d 100644
--- a/vllm/forward_context.py
+++ b/vllm/forward_context.py
@@ -37,11 +37,18 @@ def get_forward_context() -> ForwardContext:
 @contextmanager
 def set_forward_context(context: Any, vllm_config: VllmConfig):
     """A context manager that stores the current forward context,
-    can be attention metadata, etc."""
+    can be attention metadata, etc.
+    Here we can inject common logic for every model forward pass.
+    """
     global track_batchsize, batchsize_counter
     global last_logging_time, batchsize_logging_interval
     if track_batchsize and context is not None:
-        batchsize = context.num_prefill_tokens + context.num_decode_tokens
+        if hasattr(context, "num_prefill_tokens"):
+            # for v0 attention backends
+            batchsize = context.num_prefill_tokens + context.num_decode_tokens
+        else:
+            # for v1 attention backends
+            batchsize = context.num_input_tokens
         batchsize_counter[batchsize] += 1
         if time.monotonic() - last_logging_time > batchsize_logging_interval:
             last_logging_time = time.monotonic()
diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index 251a103e60f06..5b64b23840af7 100644
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -50,6 +50,7 @@ class FlashAttentionMetadata:
     #                                   |-- query_len ---|
 
     num_actual_tokens: int  # Number of tokens excluding padding.
+    num_input_tokens: int = 0  # Number of tokens including padding.
     max_query_len: int
     query_start_loc: torch.Tensor
     max_seq_len: int
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 0a5adfb28c9bd..a3335fa838352 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -445,6 +445,8 @@ def execute_model(
             # Eager mode.
             num_input_tokens = num_scheduled_tokens
 
+        attn_metadata.num_input_tokens = num_input_tokens
+
         # Get the inputs embeds.
         if encoder_outputs:
             inputs_embeds = self.model.get_input_embeddings(
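
The dispatch in the first hunk can be read as: if the metadata object exposes v0-style `num_prefill_tokens`/`num_decode_tokens`, their sum is the batch size; otherwise the v1 `num_input_tokens` (the padded count that the model runner now writes before the forward pass) is used. Below is a minimal, self-contained sketch of that logic for illustration only; `V0Metadata` and `V1Metadata` are hypothetical stand-ins for the real backend metadata classes, and the logging interval and `print` call are placeholders, not the actual vLLM logging.

```python
import time
from collections import Counter
from contextlib import contextmanager
from dataclasses import dataclass


@dataclass
class V0Metadata:
    # Stand-in for a v0 attention backend's metadata.
    num_prefill_tokens: int
    num_decode_tokens: int


@dataclass
class V1Metadata:
    # Stand-in for v1 FlashAttentionMetadata.
    num_actual_tokens: int      # tokens excluding padding
    num_input_tokens: int = 0   # tokens including padding, set by the model runner


batchsize_counter = Counter()
last_logging_time = 0.0
batchsize_logging_interval = 10.0  # seconds; illustrative value


@contextmanager
def set_forward_context(context):
    """Sketch of the batch-size tracking injected into every forward pass."""
    global last_logging_time
    if context is not None:
        if hasattr(context, "num_prefill_tokens"):
            # v0 backends report prefill and decode token counts separately.
            batchsize = context.num_prefill_tokens + context.num_decode_tokens
        else:
            # v1 backends report a single padded count.
            batchsize = context.num_input_tokens
        batchsize_counter[batchsize] += 1
        if time.monotonic() - last_logging_time > batchsize_logging_interval:
            last_logging_time = time.monotonic()
            print(sorted(batchsize_counter.items()))  # placeholder for logging
    yield


# Usage: the v1 model runner fills in the padded size before entering the context.
meta = V1Metadata(num_actual_tokens=13)
meta.num_input_tokens = 16  # e.g. padded up to a CUDA-graph batch size
with set_forward_context(meta):
    pass  # the model forward pass would run here
```

The design choice mirrored here is that v1 metadata keeps `num_actual_tokens` (unpadded) and `num_input_tokens` (padded) as separate fields, so the common forward-pass hook can count the padded batch size without knowing which backend produced the metadata.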