From 367d1d166f8fce56251980bc78172098c3c5275d Mon Sep 17 00:00:00 2001
From: Haiyang Shi
Date: Mon, 16 Dec 2024 10:52:02 -0800
Subject: [PATCH] Fix vineyard LLM cache

- disable compute skipping if seq group list has more than one element
- fix tensor slicing issue that caused incorrect generation and potential hanging during generation
- refine query token size calculation and correct matched token size calculation

Signed-off-by: Haiyang Shi
---
 vllm/worker/vineyard_llm_cache.py | 42 +++++++++++++++++++++++++------
 1 file changed, 34 insertions(+), 8 deletions(-)

diff --git a/vllm/worker/vineyard_llm_cache.py b/vllm/worker/vineyard_llm_cache.py
index 29649e5b6c0f4..3e127837af11e 100644
--- a/vllm/worker/vineyard_llm_cache.py
+++ b/vllm/worker/vineyard_llm_cache.py
@@ -97,6 +97,15 @@ def __init__(
             f"max_num_batched_tokens ({self.max_num_batched_tokens}) must" \
             f"be a multiple of chunk_size ({self.chunk_size})"
         )
+
+        # Since we calculate num of tokens by the following way:
+        # query_context_len = context_len - context_len % self.chunk_size
+        # query_token_size = context_len + token_chunk_size - query_context_len
+        # where token_chunk_size could be as large as max_num_batched_tokens.
+        # Therefore, the max num of tokens we should support is max_num_batched_tokens
+        # + chunk_size in the case of context_len is not aligned with chunk_size.
+        self.max_num_batched_tokens += self.chunk_size
+
         self.fetch_buffer, self.fetch_tensors = self._pinned_tensor_creator()
         self.cuda_buffer = self.fetch_buffer.cuda()
         self.enable_async_update = enable_async_update
@@ -254,6 +263,7 @@ def prefetch_seq_kv_caches(
         seq_group_metadata: SequenceGroupMetadata,
         kv_caches: List[torch.Tensor],
         block_size: int,
+        is_comp_skippable: bool,
     ) -> Tuple[str, int]:
         from vllm._custom_ops import reshape_and_cache_flash
         if get_tensor_model_parallel_rank() == 0:
@@ -275,6 +285,16 @@ def prefetch_seq_kv_caches(
             # alignment `context_len` to `self.chunk_size`
             query_context_len = context_len - context_len % self.chunk_size
             query_token_size = context_len + token_chunk_size - query_context_len
+            # align `query_token_size` to the next multiple of `self.chunk_size`.
+            # suppose `query_token_size` is 511 and `self.chunk_size` is 16, rather
+            # than using 496 to query, we use 512 in order to reduce the number of
+            # tokens to be recomputed.
+            query_token_size = (
+                (query_token_size + self.chunk_size - 1)
+                // self.chunk_size
+                * self.chunk_size
+            )
+            query_token_size = min(query_token_size, len(tokens) - query_context_len)
             query_prefix = tokens[:query_context_len]
             query_tokens = tokens[query_context_len:query_context_len + query_token_size]
             query_args = [
@@ -315,19 +335,22 @@ def prefetch_seq_kv_caches(
             duration = time.perf_counter() - start_time
             self.metrics.time_query.append(duration)
             self.metrics.normalized_time_query.append(duration/(len(query_tokens) + len(query_prefix)))
-        # If sampling is required, we need to leave one token unmatched
+        # shift
+        offset = context_len % self.chunk_size
+        matched -= offset
+        # If not comp skippable or sampling is required, we need to leave one token unmatched
         # to trigger the following sampling step in engine worker's workflow.
-        if seq_group_metadata is not None and seq_group_metadata.is_sampling_enabled:
+        if not is_comp_skippable or (
+            seq_group_metadata is not None and seq_group_metadata.is_sampling_enabled
+        ):
             matched = min(matched, token_chunk_size - 1)
+        else:
+            matched = min(matched, token_chunk_size)
         # synchronized across tensor parallel ranks
         matched_tensor = torch.tensor([matched], dtype=torch.long, device='cuda')
         tensor_model_parallel_all_reduce(input_=matched_tensor,
                                          op=torch.distributed.ReduceOp.MIN)
         matched = matched_tensor[0].item()
-        # shift
-        offset = context_len % self.chunk_size
-        matched -= offset
-
         if matched <= 0:
             return seq_id, 0
         if get_tensor_model_parallel_rank() == 0:
@@ -359,7 +382,7 @@ def prefetch_seq_kv_caches(
             # efficient than performing multiple smaller copy operations. This
             # approach reduces the number of transfers between CPU and GPU,
             # leading to faster overall performance.
-            buffer = self.cuda_buffer.copy_(self.fetch_buffer)[:, :, :matched]
+            buffer = self.cuda_buffer.copy_(self.fetch_buffer)[:, :, offset:offset + matched]
             copy_end.record()
             copy_end.synchronize()
             duration = copy_start.elapsed_time(copy_end) / 1000.0
@@ -405,9 +428,12 @@ def prefetch_kv_caches(
         '''
         if block_size is None or kv_caches[0] is None:  # profile run
             return {}
+        # skippable only if the seq_group_metadata_list contains a single element
+        is_comp_skippable = True
        if get_tensor_model_parallel_rank() == 0:
            prefill_requests = []
            if seq_group_metadata_list is not None:
+                is_comp_skippable = True if len(seq_group_metadata_list) == 1 else False
                for seq_group_meta in seq_group_metadata_list:
                    if seq_group_meta.is_prompt:
                        prefill_requests.append(seq_group_meta)
@@ -421,7 +447,7 @@ def prefetch_kv_caches(
         matched = {}
         for seq_group_meta in prefill_requests:
             seq_id, seq_matched = self.prefetch_seq_kv_caches(
-                seq_group_meta, kv_caches, block_size,
+                seq_group_meta, kv_caches, block_size, is_comp_skippable,
             )
             matched[seq_id] = seq_matched
         if matched:
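
A minimal, self-contained sketch of the query-sizing and offset arithmetic introduced above; it is not part of the patch, and the values for chunk_size, context_len, token_chunk_size, and prompt_len are made up for illustration (prompt_len stands in for len(tokens) in the real code):

    # Illustrative chunk-alignment arithmetic (assumed values).
    chunk_size = 16
    context_len = 21         # tokens already computed for this sequence
    token_chunk_size = 496   # tokens scheduled for this prefill step
    prompt_len = 600         # stands in for len(tokens)

    # Align the query start down to a chunk boundary; `offset` is how far
    # context_len overshoots that boundary.
    query_context_len = context_len - context_len % chunk_size    # 16
    offset = context_len % chunk_size                              # 5

    # Raw query size, rounded up to the next multiple of chunk_size so fewer
    # tokens need recomputation, then clamped to the tokens actually available.
    query_token_size = context_len + token_chunk_size - query_context_len               # 501
    query_token_size = (query_token_size + chunk_size - 1) // chunk_size * chunk_size   # 512
    query_token_size = min(query_token_size, prompt_len - query_context_len)            # 512

    # The cache reports `matched` tokens counted from query_context_len, so the
    # first `offset` of them duplicate tokens already computed locally and are
    # subtracted before deciding how much compute can be skipped.
    matched = 512
    matched -= offset   # 507 genuinely new tokens
    print(query_context_len, offset, query_token_size, matched)

This is also what motivates the constructor change: since max_num_batched_tokens is a multiple of chunk_size, the rounded-up query_token_size can exceed token_chunk_size by up to chunk_size whenever context_len is unaligned, so the fetch buffer is sized for max_num_batched_tokens + chunk_size tokens.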
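
The slicing fix can be seen with a toy tensor. This is only a sketch, not the patch's buffer layout: the shape and sizes below are invented, and all that matters is that the last dimension is the token dimension, as in the [:, :, ...] slice above. The fetched data begins at the chunk-aligned query_context_len, so the first `offset` token slots duplicate KV entries the worker already has and must be skipped:

    import torch

    offset, matched = 5, 507
    # Toy stand-in for the pinned fetch buffer (shape is made up).
    fetch_buffer = torch.arange(2 * 3 * 520, dtype=torch.float32).reshape(2, 3, 520)

    old_slice = fetch_buffer[:, :, :matched]                   # pre-patch: wrong token window
    new_slice = fetch_buffer[:, :, offset:offset + matched]    # patched: skip `offset` duplicated tokens

    assert old_slice.shape == new_slice.shape      # same number of tokens copied
    assert not torch.equal(old_slice, new_slice)   # but a different (now correct) window

Copying the wrong window feeds misaligned KV entries into the cache, which lines up with the incorrect generation reported in the commit message.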
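
Finally, a hypothetical condensation of the compute-skip gate (the helper name and standalone form are not in the patch; the logic mirrors the hunk at @@ -315,19 +335,22 @@):

    def cap_matched(matched: int, token_chunk_size: int,
                    is_comp_skippable: bool, sampling_enabled: bool) -> int:
        # When the batch has more than one sequence group (not skippable) or
        # sampling is required, leave at least one token unmatched so the
        # engine still runs a forward/sampling step for this group.
        if not is_comp_skippable or sampling_enabled:
            return min(matched, token_chunk_size - 1)
        # Otherwise the full chunk may be served from the vineyard cache.
        return min(matched, token_chunk_size)

is_comp_skippable is computed once per batch in prefetch_kv_caches and is True only when seq_group_metadata_list has a single element, per the first bullet of the commit message.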