Fix vineyard LLM cache (#18)
- disable compute skipping if seq group list has more than one element
- fix tensor slicing issue that caused incorrect generation and
  potential hanging during generation
- refine query token size calculation and correct matched token size
  calculation

Signed-off-by: Haiyang Shi <[email protected]>
Co-authored-by: Haiyang Shi <[email protected]>
DwyaneShi and Haiyang Shi authored Dec 26, 2024
1 parent 8d34fa4 commit eb83e1d
Showing 1 changed file with 35 additions and 7 deletions.
vllm/worker/vineyard_llm_cache.py
@@ -145,6 +145,15 @@ def __init__(
f"max_num_batched_tokens ({self.max_num_batched_tokens}) must" \
f"be a multiple of chunk_size ({self.chunk_size})"
)

# We calculate the number of query tokens as follows:
#   query_context_len = context_len - context_len % self.chunk_size
#   query_token_size = context_len + token_chunk_size - query_context_len
# where token_chunk_size can be as large as max_num_batched_tokens. Therefore,
# the maximum number of tokens we need to support is max_num_batched_tokens
# + chunk_size, for the case where context_len is not aligned with chunk_size.
self.max_num_batched_tokens += self.chunk_size

self.fetch_buffer, self.fetch_tensors = self._pinned_tensor_creator()
self.cuda_buffer = self.fetch_buffer.cuda()
self.enable_async_update = enable_async_update
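To make the headroom argument in the comment above concrete, here is a small standalone sketch (not part of the patch; the numbers are made up and the names simply mirror those in the diff):

# Worked example of the worst-case query size (hypothetical numbers).
chunk_size = 16
max_num_batched_tokens = 512
context_len = 23                            # not aligned to chunk_size
token_chunk_size = max_num_batched_tokens   # largest possible chunk

query_context_len = context_len - context_len % chunk_size              # 16
query_token_size = context_len + token_chunk_size - query_context_len   # 519

# 519 exceeds max_num_batched_tokens (512), so the fetch buffer needs room
# for up to max_num_batched_tokens + chunk_size tokens.
assert query_token_size <= max_num_batched_tokens + chunk_size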
@@ -308,6 +317,7 @@ def prefetch_seq_kv_caches(
seq_group_metadata: SequenceGroupMetadata,
kv_caches: List[torch.Tensor],
block_size: int,
is_comp_skippable: bool,
) -> Tuple[str, int]:
from vllm._custom_ops import reshape_and_cache_flash
if get_tensor_model_parallel_rank() == 0:
@@ -329,6 +339,16 @@
# alignment `context_len` to `self.chunk_size`
query_context_len = context_len - context_len % self.chunk_size
query_token_size = context_len + token_chunk_size - query_context_len
# Align `query_token_size` to the next multiple of `self.chunk_size`.
# For example, if `query_token_size` is 511 and `self.chunk_size` is 16,
# we query with 512 rather than rounding down to 496, which reduces the
# number of tokens that have to be recomputed.
query_token_size = (
(query_token_size + self.chunk_size - 1)
// self.chunk_size
* self.chunk_size
)
query_token_size = min(query_token_size, len(tokens) - query_context_len)
query_prefix = tokens[:query_context_len]
query_tokens = tokens[query_context_len:query_context_len + query_token_size]
query_args = [
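The rounding in the lines above can be read as a round-up-to-multiple helper; the following is an illustrative sketch (round_up_to_chunk is a hypothetical name, not part of the patch):

def round_up_to_chunk(n: int, chunk_size: int) -> int:
    """Round n up to the next multiple of chunk_size."""
    return (n + chunk_size - 1) // chunk_size * chunk_size

# Mirrors the example in the comment: 511 tokens with a 16-token chunk.
assert round_up_to_chunk(511, 16) == 512
# The patch then clamps the result so the query never runs past the prompt:
#   query_token_size = min(query_token_size, len(tokens) - query_context_len)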
@@ -373,18 +393,22 @@ def prefetch_seq_kv_caches(
if self.metrics_enabled:
duration = time.perf_counter() - start_time
self.metrics.add_time_query(duration)
# If sampling is required, we need to leave one token unmatched
# shift
offset = context_len % self.chunk_size
matched -= offset
# If computation is not skippable or sampling is required, we need to
# leave one token unmatched to trigger the following sampling step in the
# engine worker's workflow.
if seq_group_metadata is not None and seq_group_metadata.is_sampling_enabled:
if not is_comp_skippable or (
seq_group_metadata is not None and seq_group_metadata.is_sampling_enabled
):
matched = min(matched, token_chunk_size - 1)
else:
matched = min(matched, token_chunk_size)
# synchronized across tensor parallel ranks
matched_tensor = torch.tensor([matched], dtype=torch.long, device='cuda')
tensor_model_parallel_all_reduce(input_=matched_tensor, op=torch.distributed.ReduceOp.MIN)
matched = matched_tensor[0].item()

# shift
offset = context_len % self.chunk_size
matched -= offset
if matched <= 0:
return seq_id, 0
if get_tensor_model_parallel_rank() == 0:
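Condensing the matched-token handling above into one place, the new ordering (clamp first, all-reduce across ranks, then shift by the chunk offset) could be sketched as follows; all values are hypothetical:

# Hypothetical values for illustration only.
token_chunk_size = 128
context_len = 23
chunk_size = 16
is_comp_skippable = False     # batch contained more than one sequence group
sampling_enabled = True
matched = 135                 # tokens reported as matched by the cache query

# Leave one token unmatched when the step cannot be skipped entirely, so the
# engine still runs the sampling step for this sequence.
if not is_comp_skippable or sampling_enabled:
    matched = min(matched, token_chunk_size - 1)    # 127
else:
    matched = min(matched, token_chunk_size)

# (In the real code, `matched` is all-reduced with MIN across tensor-parallel
# ranks at this point.)

# Shift away the unaligned prefix that was only included for chunk alignment.
offset = context_len % chunk_size    # 7
matched -= offset                    # 120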
@@ -413,7 +437,7 @@ def prefetch_seq_kv_caches(
# efficient than performing multiple smaller copy operations. This
# approach reduces the number of transfers between CPU and GPU,
# leading to faster overall performance.
buffer = self.cuda_buffer.copy_(self.fetch_buffer)[:, :, :matched]
buffer = self.cuda_buffer.copy_(self.fetch_buffer)[:, :, offset:offset+matched]
if self.metrics_enabled:
copy_end.record()
copy_end.synchronize()
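The one-line change in this hunk is the tensor-slicing fix mentioned in the commit message: the tokens of interest start at `offset` inside the fetched buffer, so slicing from 0 handed back the wrong tokens. A minimal sketch of the difference, with an invented buffer shape purely for illustration:

import torch

# Pretend fetch buffer laid out as [2 (k/v), layers, tokens].
fetch_buffer = torch.arange(2 * 1 * 10).reshape(2, 1, 10)
offset, matched = 3, 4

wrong = fetch_buffer[:, :, :matched]                 # token indices 0..3
fixed = fetch_buffer[:, :, offset:offset + matched]  # intended indices 3..6

assert not torch.equal(wrong, fixed)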
@@ -424,6 +448,7 @@
reshape_start = torch.cuda.Event(enable_timing=True)
reshape_end = torch.cuda.Event(enable_timing=True)
reshape_start.record()

for j in range(self.layer):
# use `reshape_and_cache_flash` rather than `copy_` as
# the target kv cache slots are not contiguous.
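The comment above concerns scattering fetched KV entries into cache slots that are not contiguous, which a plain copy_ into a slice cannot express. A toy illustration of the idea using basic indexing (this does not reflect the actual reshape_and_cache_flash signature):

import torch

kv_cache = torch.zeros(8, 4)              # 8 slots, 4-dim entries
fetched = torch.ones(3, 4)                # 3 fetched tokens
slot_mapping = torch.tensor([1, 5, 6])    # non-contiguous target slots

# A contiguous copy_ cannot target scattered slots; index_copy_ scatters
# each fetched row into its mapped slot.
kv_cache.index_copy_(0, slot_mapping, fetched)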
@@ -459,9 +484,12 @@ def prefetch_kv_caches(
'''
if block_size is None or kv_caches[0] is None: # profile run
return {}
# skippable only if the seq_group_metadata_list contains a single element
is_comp_skippable = True
if get_tensor_model_parallel_rank() == 0:
prefill_requests = []
if seq_group_metadata_list is not None:
is_comp_skippable = True if len(seq_group_metadata_list) == 1 else False
for seq_group_meta in seq_group_metadata_list:
if seq_group_meta.is_prompt:
prefill_requests.append(seq_group_meta)
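This is the gate behind the first bullet of the commit message: compute skipping is only treated as safe when the scheduled batch holds a single sequence group. A condensed, self-contained sketch of the flag (the list contents are stand-ins, not real SequenceGroupMetadata objects):

# Hypothetical batch: two sequence groups scheduled in the same step.
seq_group_metadata_list = ["seq_group_0", "seq_group_1"]

# Compute skipping is only safe when the batch holds exactly one sequence
# group; otherwise at least one token per group must still be computed.
is_comp_skippable = (
    seq_group_metadata_list is not None
    and len(seq_group_metadata_list) == 1
)
assert is_comp_skippable is False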
@@ -475,7 +503,7 @@
matched = {}
for seq_group_meta in prefill_requests:
seq_id, seq_matched = self.prefetch_seq_kv_caches(
seq_group_meta, kv_caches, block_size,
seq_group_meta, kv_caches, block_size, is_comp_skippable,
)
matched[seq_id] = seq_matched
if matched: