From 1a58023c03ce37ae3b2bb84c73558ee603dd0263 Mon Sep 17 00:00:00 2001
From: Haiyang Shi
Date: Tue, 5 Nov 2024 11:52:36 -0800
Subject: [PATCH] Optimize workflow w/ kv cache

Optimize the model runner workflow to support bypassing the sampling step
for intermediate chunks that are fully hit in the KV cache.

Signed-off-by: Haiyang Shi
---
 vllm/model_executor/sampling_metadata.py | 17 ++++++++++-------
 vllm/sequence.py                         |  5 +++++
 vllm/worker/model_runner.py              | 18 +++++++++++++++++-
 vllm/worker/vineyard_llm_cache.py        |  8 ++++----
 4 files changed, 36 insertions(+), 12 deletions(-)

diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py
index a085779bc61a7..039044f82ee33 100644
--- a/vllm/model_executor/sampling_metadata.py
+++ b/vllm/model_executor/sampling_metadata.py
@@ -280,13 +280,16 @@ def _prepare_seq_groups(
             num_prompts += 1
             num_prefill_sample = len(seq_ids)
             assert num_prefill_sample == 1
-            assert query_lens is not None and seq_lens is not None
-            query_len, seq_len = query_lens[i], seq_lens[i]
-            # If we need sampling, exclude num_prefill_sample tokens from
-            # prompt logprob.
-            prompt_logprob_len = (query_len - num_prefill_sample
-                                  if do_sample else query_len)
-            sample_len = num_prefill_sample if do_sample else 0
+            if query_lens is not None and seq_lens is not None:
+                query_len, seq_len = query_lens[i], seq_lens[i]
+                # If we need sampling, exclude num_prefill_sample tokens
+                # from prompt logprob.
+                prompt_logprob_len = (query_len - num_prefill_sample
+                                      if do_sample else query_len)
+                sample_len = num_prefill_sample if do_sample else 0
+            else:
+                prompt_logprob_len = 0
+                sample_len = 0
         else:
             # Decode
             prompt_logprob_len = 0
diff --git a/vllm/sequence.py b/vllm/sequence.py
index 135586831e680..f57b7d677d3ef 100644
--- a/vllm/sequence.py
+++ b/vllm/sequence.py
@@ -942,6 +942,11 @@ def prompt_adapter_num_virtual_tokens(self) -> int:
         return self.prompt_adapter_request.prompt_adapter_num_virtual_tokens \
             if self.prompt_adapter_request else 0
 
+    @property
+    def is_sampling_enabled(self) -> bool:
+        return self.sampling_params.prompt_logprobs is not None \
+            or self.do_sample
+
     def apply_delta(self,
                     sequence_group_metadata_delta: SequenceGroupMetadataDelta):
         for id, delta in sequence_group_metadata_delta.seq_data_delta.items():
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index df359cc9e649a..8f3fcd1eecc44 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -45,7 +45,8 @@
 from vllm.prompt_adapter.worker_manager import (
     LRUCacheWorkerPromptAdapterManager)
 from vllm.sampling_params import SamplingParams
-from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
+from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
+                           SequenceGroupMetadata)
 from vllm.utils import (CudaMemoryProfiler, PyObjectCache, async_tensor_h2d,
                         flatten_2d_lists, is_hip, is_pin_memory_available,
                         get_kv_cache_torch_dtype, supports_dynamo)
@@ -1564,6 +1565,21 @@ def execute_model(
 
         self.attn_state.begin_forward(model_input)
 
+        # When using a KV cache with chunked prefill enabled and sampling not
+        # explicitly enabled, for the first 'n' chunks (except the last chunk
+        # before the decode phase), if there is a full hit in the KV cache,
+        # all KV tensors are fetched from the cache, and input tokens are
+        # set to None. Thus, if we encounter None for input_tokens, we just
+        # skip sampling and return empty outputs.
+        if model_input.input_tokens is None:
+            outputs = [
+                CompletionSequenceGroupOutput(
+                    samples=[],
+                    prompt_logprobs=None
+                )
+            ]
+            return [SamplerOutput(outputs=outputs)]
+
         # Currently cuda graph is only supported by the decode phase.
         assert model_input.attn_metadata is not None
         prefill_meta = model_input.attn_metadata.prefill_metadata
diff --git a/vllm/worker/vineyard_llm_cache.py b/vllm/worker/vineyard_llm_cache.py
index e5bd86cf33e22..d665736eaf4c9 100644
--- a/vllm/worker/vineyard_llm_cache.py
+++ b/vllm/worker/vineyard_llm_cache.py
@@ -158,7 +158,7 @@ def prefetch_seq_kv_caches(
         tokens = seq_data.get_prompt_token_ids()
 
         # leave at least one token unmatched
-        token_chunk_size -= 1
+        # token_chunk_size -= 1
 
         # alignment `context_len` to `self.chunk_size`
         query_context_len = context_len - context_len % self.chunk_size
@@ -198,6 +198,9 @@
         duration = time.perf_counter() - start_time
         self.metrics.time_query.append(duration)
         self.metrics.normalized_time_query.append(duration/len(tokens))
+        # No need to subtract 1 again; only cap matched when sampling is enabled.
+        if seq_group_metadata is not None and seq_group_metadata.is_sampling_enabled:
+            matched = min(matched, token_chunk_size - 1)
         # synchronized across tensor parallel ranks
         matched_tensor = torch.tensor([matched], dtype=torch.long, device='cuda')
         # torch.distributed.all_reduce(matched_tensor, op=torch.distributed.ReduceOp.MIN,
@@ -208,9 +211,6 @@
 
         offset = context_len % self.chunk_size
         matched -= offset
-        # we force to use token_chunk_size - 1 to trigger KV recomputation
-        # TODO: this should be revisited later. We are looking for solutions to fully avoid computation.
-        matched = min(matched, token_chunk_size - 1)
         if matched <= 0:
             return seq_id, 0
         if seq_group_metadata is not None:
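
Note (not part of the patch): a minimal sketch of the relaxed length computation
introduced in _prepare_seq_groups(), pulled out into a standalone helper for
illustration; the function name prefill_sample_lens is hypothetical.

    from typing import Optional, Tuple

    def prefill_sample_lens(query_len: Optional[int],
                            num_prefill_sample: int,
                            do_sample: bool) -> Tuple[int, int]:
        # When the chunk is fully served from the KV cache there is no query
        # length to account for, so both lengths collapse to zero and the
        # sampler has nothing to do for this step.
        if query_len is None:
            return 0, 0
        prompt_logprob_len = (query_len - num_prefill_sample
                              if do_sample else query_len)
        sample_len = num_prefill_sample if do_sample else 0
        return prompt_logprob_len, sample_len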
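
Note (not part of the patch): a minimal sketch of the output that execute_model()
now returns on the bypass path, assuming SamplerOutput is importable from
vllm.model_executor.layers.sampler as in recent vLLM releases; the helper name
empty_sampler_output is hypothetical.

    from vllm.model_executor.layers.sampler import SamplerOutput
    from vllm.sequence import CompletionSequenceGroupOutput

    def empty_sampler_output() -> SamplerOutput:
        # One sequence-group output with no samples and no prompt logprobs:
        # a valid but empty SamplerOutput, returned when input_tokens is None
        # because the whole chunk was served from the KV cache and sampling
        # is not needed.
        return SamplerOutput(outputs=[
            CompletionSequenceGroupOutput(samples=[], prompt_logprobs=None)
        ])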
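
Note (not part of the patch): a self-contained sketch of the new capping rule in
prefetch_seq_kv_caches(); cap_matched_tokens and its parameters are stand-ins,
not the real method signature.

    def cap_matched_tokens(matched: int, token_chunk_size: int,
                           sampling_enabled: bool) -> int:
        # When sampling or prompt logprobs are required, keep at least one
        # token of the chunk unmatched so the model still computes logits for
        # it; otherwise the entire chunk may be served from the KV cache.
        if sampling_enabled:
            return min(matched, token_chunk_size - 1)
        return matched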