Optimize workflow w/ kv cache
Optimize the model runner workflow to support bypassing the sampling step for intermediate chunks that are fully hit in the KV cache.

Signed-off-by: Haiyang Shi <[email protected]>
Haiyang Shi committed Nov 6, 2024
1 parent 25f7781 commit 1a58023
Showing 4 changed files with 36 additions and 12 deletions.
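
In essence, when an intermediate prefill chunk is fully served from the external KV cache, the model runner now returns an empty SamplerOutput instead of running the forward pass and the sampler. Below is a minimal, self-contained sketch of that early-return control flow; the dataclasses are simplified stand-ins for vLLM's real types (ModelInputForGPUWithSamplingMetadata, CompletionSequenceGroupOutput, SamplerOutput), not the actual definitions from the diff.

from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class CompletionSequenceGroupOutput:
    # Stand-in for vLLM's per-sequence-group output container.
    samples: List[object] = field(default_factory=list)
    prompt_logprobs: Optional[List[object]] = None


@dataclass
class SamplerOutput:
    # Stand-in for vLLM's SamplerOutput.
    outputs: List[CompletionSequenceGroupOutput] = field(default_factory=list)


@dataclass
class ModelInput:
    # Stand-in for ModelInputForGPUWithSamplingMetadata; only the field
    # relevant to this change is modeled.
    input_tokens: Optional[List[int]] = None


def execute_model_sketch(model_input: ModelInput) -> List[SamplerOutput]:
    # If the KV cache fully covered this intermediate prefill chunk, the cache
    # layer leaves input_tokens as None: there is nothing to run the model on
    # and nothing to sample, so return an empty output up front.
    if model_input.input_tokens is None:
        empty = CompletionSequenceGroupOutput(samples=[], prompt_logprobs=None)
        return [SamplerOutput(outputs=[empty])]
    # Otherwise the normal forward pass + sampling path would run here.
    raise NotImplementedError("full execution path omitted in this sketch")


if __name__ == "__main__":
    print(execute_model_sketch(ModelInput(input_tokens=None)))
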
vllm/model_executor/sampling_metadata.py (17 changes: 10 additions & 7 deletions)
@@ -280,13 +280,16 @@ def _prepare_seq_groups(
             num_prompts += 1
             num_prefill_sample = len(seq_ids)
             assert num_prefill_sample == 1
-            assert query_lens is not None and seq_lens is not None
-            query_len, seq_len = query_lens[i], seq_lens[i]
-            # If we need sampling, exclude num_prefill_sample tokens from
-            # prompt logprob.
-            prompt_logprob_len = (query_len - num_prefill_sample
-                                  if do_sample else query_len)
-            sample_len = num_prefill_sample if do_sample else 0
+            if query_lens is not None and seq_lens is not None:
+                query_len, seq_len = query_lens[i], seq_lens[i]
+                # If we need sampling, exclude num_prefill_sample tokens
+                # from prompt logprob.
+                prompt_logprob_len = (query_len - num_prefill_sample
+                                      if do_sample else query_len)
+                sample_len = num_prefill_sample if do_sample else 0
+            else:
+                prompt_logprob_len = 0
+                sample_len = 0
         else:
             # Decode
             prompt_logprob_len = 0
vllm/sequence.py (5 changes: 5 additions & 0 deletions)
@@ -942,6 +942,11 @@ def prompt_adapter_num_virtual_tokens(self) -> int:
         return self.prompt_adapter_request.prompt_adapter_num_virtual_tokens \
                         if self.prompt_adapter_request else 0
 
+    @property
+    def is_sampling_enabled(self) -> bool:
+        return self.sampling_params.prompt_logprobs is not None \
+            or self.do_sample
+
     def apply_delta(self,
                     sequence_group_metadata_delta: SequenceGroupMetadataDelta):
         for id, delta in sequence_group_metadata_delta.seq_data_delta.items():
vllm/worker/model_runner.py (18 changes: 17 additions & 1 deletion)
@@ -45,7 +45,8 @@
 from vllm.prompt_adapter.worker_manager import (
     LRUCacheWorkerPromptAdapterManager)
 from vllm.sampling_params import SamplingParams
-from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
+from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
+                           SequenceGroupMetadata)
 from vllm.utils import (CudaMemoryProfiler, PyObjectCache, async_tensor_h2d,
                         flatten_2d_lists, is_hip, is_pin_memory_available,
                         get_kv_cache_torch_dtype, supports_dynamo)
@@ -1564,6 +1565,21 @@ def execute_model(

         self.attn_state.begin_forward(model_input)
 
+        # When using a KV cache with chunked prefill enabled and sampling not
+        # explicitly enabled, for the first 'n' chunks (except the last chunk
+        # before the decode phase), if there is a full hit in the KV cache,
+        # all KV tensors are fetched from the cache and the input tokens are
+        # set to None. Thus, if we encounter None for input_tokens, we simply
+        # skip sampling and return an empty output.
+        if model_input.input_tokens is None:
+            outputs = [
+                CompletionSequenceGroupOutput(
+                    samples=[],
+                    prompt_logprobs=None
+                )
+            ]
+            return [SamplerOutput(outputs=outputs)]
+
         # Currently cuda graph is only supported by the decode phase.
         assert model_input.attn_metadata is not None
         prefill_meta = model_input.attn_metadata.prefill_metadata
vllm/worker/vineyard_llm_cache.py (8 changes: 4 additions & 4 deletions)
@@ -158,7 +158,7 @@ def prefetch_seq_kv_caches(
         tokens = seq_data.get_prompt_token_ids()
 
         # leave at least one token unmatched
-        token_chunk_size -= 1
+        # token_chunk_size -= 1
 
         # alignment `context_len` to `self.chunk_size`
         query_context_len = context_len - context_len % self.chunk_size
@@ -198,6 +198,9 @@ def prefetch_seq_kv_caches(
         duration = time.perf_counter() - start_time
         self.metrics.time_query.append(duration)
         self.metrics.normalized_time_query.append(duration/len(tokens))
+        # No need to subtract 1 again: matched = min(matched, token_chunk_size - 1)
+        if seq_group_metadata is not None and seq_group_metadata.is_sampling_enabled:
+            matched = min(matched, token_chunk_size - 1)
         # synchronized across tensor parallel ranks
         matched_tensor = torch.tensor([matched], dtype=torch.long, device='cuda')
         # torch.distributed.all_reduce(matched_tensor, op=torch.distributed.ReduceOp.MIN,
@@ -208,9 +211,6 @@
         offset = context_len % self.chunk_size
         matched -= offset
 
-        # we force to use token_chunk_size - 1 to trigger KV recomputation
-        # TODO: this should be revisited later. We are looking for solutions to fully avoid computation.
-        matched = min(matched, token_chunk_size - 1)
         if matched <= 0:
             return seq_id, 0
         if seq_group_metadata is not None:
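
Read together with the is_sampling_enabled property added in vllm/sequence.py, the change above means the cache-query path only holds back a token for recomputation when sampling or prompt logprobs are actually required for the chunk. A small illustrative sketch of that gating follows; cap_matched_tokens is a hypothetical helper name used here for clarity, not a function introduced by the diff.

def cap_matched_tokens(matched: int,
                       token_chunk_size: int,
                       sampling_enabled: bool) -> int:
    # When sampling (or prompt logprobs) is required, at least one token of
    # the chunk must still be recomputed so the model produces logits for it,
    # so the cache hit is capped at token_chunk_size - 1. Otherwise the whole
    # chunk can be served from the KV cache.
    if sampling_enabled:
        return min(matched, token_chunk_size - 1)
    return matched


if __name__ == "__main__":
    # Intermediate chunk, fully cached, no sampling needed: reuse everything.
    print(cap_matched_tokens(matched=512, token_chunk_size=512,
                             sampling_enabled=False))  # 512
    # Sampling needed for this chunk: leave one token to recompute.
    print(cap_matched_tokens(matched=512, token_chunk_size=512,
                             sampling_enabled=True))   # 511
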
