Commit

comment
Signed-off-by: youkaichao <[email protected]>
youkaichao committed Dec 10, 2024
1 parent 38763d2 commit d741330
Showing 3 changed files with 12 additions and 2 deletions.
11 changes: 9 additions & 2 deletions vllm/forward_context.py
@@ -37,11 +37,18 @@ def get_forward_context() -> ForwardContext:
 @contextmanager
 def set_forward_context(context: Any, vllm_config: VllmConfig):
     """A context manager that stores the current forward context,
-    can be attention metadata, etc."""
+    can be attention metadata, etc.
+    Here we can inject common logic for every model forward pass.
+    """
     global track_batchsize, batchsize_counter
     global last_logging_time, batchsize_logging_interval
     if track_batchsize and context is not None:
-        batchsize = context.num_prefill_tokens + context.num_decode_tokens
+        if hasattr(context, "num_prefill_tokens"):
+            # for v0 attention backends
+            batchsize = context.num_prefill_tokens + context.num_decode_tokens
+        else:
+            # for v1 attention backends
+            batchsize = context.num_input_tokens
         batchsize_counter[batchsize] += 1
         if time.monotonic() - last_logging_time > batchsize_logging_interval:
             last_logging_time = time.monotonic()
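The hasattr check above dispatches on the shape of the metadata rather than on its type: v0 attention metadata exposes num_prefill_tokens and num_decode_tokens, while v1 metadata exposes a single num_input_tokens. A minimal, self-contained sketch of that duck-typed dispatch (the two dataclasses are simplified stand-ins for illustration, not the real vLLM types):

from dataclasses import dataclass

@dataclass
class V0Metadata:  # stand-in for a v0 attention metadata class
    num_prefill_tokens: int
    num_decode_tokens: int

@dataclass
class V1Metadata:  # stand-in for a v1 attention metadata class
    num_input_tokens: int

def batchsize_of(context) -> int:
    # Same dispatch as in set_forward_context: probe for the v0 field
    # names first, fall back to the v1 field.
    if hasattr(context, "num_prefill_tokens"):
        return context.num_prefill_tokens + context.num_decode_tokens
    return context.num_input_tokens

assert batchsize_of(V0Metadata(3, 5)) == 8
assert batchsize_of(V1Metadata(8)) == 8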
1 change: 1 addition & 0 deletions vllm/v1/attention/backends/flash_attn.py
@@ -50,6 +50,7 @@ class FlashAttentionMetadata:
     # |-- query_len ---|

     num_actual_tokens: int  # Number of tokens excluding padding.
+    num_input_tokens: int = 0  # Number of tokens including padding.
     max_query_len: int
     query_start_loc: torch.Tensor
     max_seq_len: int
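The = 0 default keeps existing call sites that construct FlashAttentionMetadata without the new field working unchanged; the runner overwrites it before the forward pass. A small sketch of the intended invariant (the field names match the diff, but the dataclass is a trimmed stand-in, not the full vLLM class):

from dataclasses import dataclass

@dataclass
class Meta:  # trimmed stand-in for FlashAttentionMetadata
    num_actual_tokens: int     # number of tokens excluding padding
    num_input_tokens: int = 0  # number of tokens including padding

meta = Meta(num_actual_tokens=5)  # existing constructors still work
meta.num_input_tokens = 8         # the runner stamps the padded count later
assert meta.num_input_tokens >= meta.num_actual_tokens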
2 changes: 2 additions & 0 deletions vllm/v1/worker/gpu_model_runner.py
@@ -445,6 +445,8 @@ def execute_model(
             # Eager mode.
             num_input_tokens = num_scheduled_tokens

+        attn_metadata.num_input_tokens = num_input_tokens
+
         # Get the inputs embeds.
         if encoder_outputs:
             inputs_embeds = self.model.get_input_embeddings(
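By the time this assignment runs, execute_model has already chosen between a padded batch size for a captured CUDA graph and the raw count in eager mode (the "# Eager mode." context above). A hedged sketch of that control flow; the padding helper and capture sizes here are assumptions for illustration, not vLLM's actual code:

# Assumed CUDA-graph capture sizes, for illustration only.
CAPTURE_SIZES = [1, 2, 4, 8, 16, 32]

def resolve_num_input_tokens(num_scheduled_tokens: int,
                             use_cuda_graph: bool) -> int:
    if use_cuda_graph:
        for size in CAPTURE_SIZES:
            if size >= num_scheduled_tokens:
                # Pad up to the nearest captured batch size.
                return size
    # Eager mode.
    return num_scheduled_tokens

num_input_tokens = resolve_num_input_tokens(5, use_cuda_graph=True)
assert num_input_tokens == 8  # 3 padding tokens; the batchsize logger now sees 8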
