
Fix vineyard LLM cache #18

Merged · 1 commit merged on Dec 26, 2024
42 changes: 35 additions & 7 deletions vllm/worker/vineyard_llm_cache.py
@@ -145,6 +145,15 @@ def __init__(
f"max_num_batched_tokens ({self.max_num_batched_tokens}) must" \
f"be a multiple of chunk_size ({self.chunk_size})"
)

# We calculate the number of query tokens as follows:
#   query_context_len = context_len - context_len % self.chunk_size
#   query_token_size = context_len + token_chunk_size - query_context_len
# where token_chunk_size can be as large as max_num_batched_tokens.
# Therefore, the maximum number of tokens we need to support is
# max_num_batched_tokens + chunk_size, for the case where context_len is
# not aligned with chunk_size.
self.max_num_batched_tokens += self.chunk_size

self.fetch_buffer, self.fetch_tensors = self._pinned_tensor_creator()
self.cuda_buffer = self.fetch_buffer.cuda()
self.enable_async_update = enable_async_update
@@ -308,6 +317,7 @@ def prefetch_seq_kv_caches(
seq_group_metadata: SequenceGroupMetadata,
kv_caches: List[torch.Tensor],
block_size: int,
is_comp_skippable: bool,
) -> Tuple[str, int]:
from vllm._custom_ops import reshape_and_cache_flash
if get_tensor_model_parallel_rank() == 0:
@@ -329,6 +339,16 @@
# alignment `context_len` to `self.chunk_size`
query_context_len = context_len - context_len % self.chunk_size
query_token_size = context_len + token_chunk_size - query_context_len
# align `query_token_size` to the next multiple of `self.chunk_size`.
Reviewer:
Is this just a performance improvement or a bug fix?

For the benefit: let's say we have 511 tokens and a chunk size of 16. In this case query_context_len = 496; what's the value of query_token_size?

Reviewer:
How can query_token_size be 511?

Collaborator Author:
  • a performance improvement
  • if query_context_len = 496, then query_token_size = 16
  • when a decode is scheduled together with the prefill, suppose context_len = 0; then the query_token_size calculated by query_token_size = context_len + token_chunk_size - query_context_len will be 511, since one slot is taken by the decode and token_chunk_size = 511
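
To make the arithmetic in this thread concrete, here is a standalone sketch, not part of the diff, that reproduces the alignment calculation; `query_window` is a hypothetical helper name and the numbers mirror the 511-token / 16-chunk example above:

```python
# Standalone illustration of the chunk-alignment math discussed above;
# variable names mirror the diff for readability only.
def query_window(context_len: int, token_chunk_size: int, chunk_size: int,
                 num_tokens: int):
    # Align the context down to a chunk boundary.
    query_context_len = context_len - context_len % chunk_size
    query_token_size = context_len + token_chunk_size - query_context_len
    # Round up to the next chunk boundary so fewer tokens are recomputed,
    # then clamp to the tokens that actually exist.
    query_token_size = (query_token_size + chunk_size - 1) // chunk_size * chunk_size
    query_token_size = min(query_token_size, num_tokens - query_context_len)
    return query_context_len, query_token_size

# Prefill of 511 tokens scheduled together with a decode (one slot taken by
# the decode), so token_chunk_size = 511 and context_len = 0:
print(query_window(0, 511, 16, 511))   # (0, 511): rounded up to 512, clamped to 511
# A later 16-token chunk with context_len = 496:
print(query_window(496, 16, 16, 512))  # (496, 16)
# Unaligned context with a full 512-token budget: the query can cover up to
# max_num_batched_tokens + chunk_size tokens, hence the enlarged fetch buffer.
print(query_window(8, 512, 16, 520))   # (0, 520)
```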

# Suppose `query_token_size` is 511 and `self.chunk_size` is 16: rather
# than using 496 to query, we use 512 in order to reduce the number of
# tokens to be recomputed.
query_token_size = (
(query_token_size + self.chunk_size - 1)
// self.chunk_size
* self.chunk_size
)
query_token_size = min(query_token_size, len(tokens) - query_context_len)
query_prefix = tokens[:query_context_len]
query_tokens = tokens[query_context_len:query_context_len + query_token_size]
query_args = [
@@ -373,18 +393,22 @@ def prefetch_seq_kv_caches(
if self.metrics_enabled:
duration = time.perf_counter() - start_time
self.metrics.add_time_query(duration)
# If sampling is required, we need to leave one token unmatched
# shift
offset = context_len % self.chunk_size
matched -= offset
# If the computation is not skippable or sampling is required, we need to
# leave one token unmatched to trigger the following sampling step in the
# engine worker's workflow.
if seq_group_metadata is not None and seq_group_metadata.is_sampling_enabled:
if not is_comp_skippable or (
Reviewer:
Is this the workaround you mentioned? We can still explore whether it's possible to rebuild some seq group metadata later in this case, right? Or has it been fully evaluated and there's no chance of a performance gain here?

Collaborator Author:
We can explore it later, but my gut feeling is that this wouldn't gain much benefit.
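
For reference, a minimal standalone sketch of the capping rule this new condition implements (not the PR's code; `cap_matched` is a hypothetical name): when the computation cannot be skipped, or sampling is enabled for the sequence group, at most token_chunk_size - 1 tokens may be served from the cache so the model still computes at least one token and the sampling step runs.

```python
# Hypothetical helper sketching the capping rule; not part of the PR.
def cap_matched(matched: int, token_chunk_size: int,
                is_comp_skippable: bool, sampling_enabled: bool) -> int:
    if not is_comp_skippable or sampling_enabled:
        # Leave at least one token for the model to compute so the engine
        # worker's sampling step is still triggered for this sequence group.
        return min(matched, token_chunk_size - 1)
    # Fully skippable: the whole chunk may be served from the cache.
    return min(matched, token_chunk_size)

assert cap_matched(16, 16, is_comp_skippable=False, sampling_enabled=False) == 15
assert cap_matched(16, 16, is_comp_skippable=True, sampling_enabled=False) == 16
```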

seq_group_metadata is not None and seq_group_metadata.is_sampling_enabled
):
matched = min(matched, token_chunk_size - 1)
else:
matched = min(matched, token_chunk_size)
# synchronized across tensor parallel ranks
matched_tensor = torch.tensor([matched], dtype=torch.long, device='cuda')
tensor_model_parallel_all_reduce(input_=matched_tensor, op=torch.distributed.ReduceOp.MIN)
matched = matched_tensor[0].item()

# shift
offset = context_len % self.chunk_size
matched -= offset
if matched <= 0:
return seq_id, 0
if get_tensor_model_parallel_rank() == 0:
@@ -413,7 +437,7 @@ def prefetch_seq_kv_caches(
# efficient than performing multiple smaller copy operations. This
# approach reduces the number of transfers between CPU and GPU,
# leading to faster overall performance.
buffer = self.cuda_buffer.copy_(self.fetch_buffer)[:, :, :matched]
buffer = self.cuda_buffer.copy_(self.fetch_buffer)[:, :, offset:offset+matched]
if self.metrics_enabled:
copy_end.record()
copy_end.synchronize()
@@ -424,6 +448,7 @@
reshape_start = torch.cuda.Event(enable_timing=True)
reshape_end = torch.cuda.Event(enable_timing=True)
reshape_start.record()

for j in range(self.layer):
# use `reshape_and_cache_flash` rather than `copy_` as
# the target kv cache slots are not contiguous.
@@ -459,9 +484,12 @@ def prefetch_kv_caches(
'''
if block_size is None or kv_caches[0] is None: # profile run
return {}
# skippable only if the seq_group_metadata_list contains a single element
is_comp_skippable = True
if get_tensor_model_parallel_rank() == 0:
prefill_requests = []
if seq_group_metadata_list is not None:
is_comp_skippable = True if len(seq_group_metadata_list) == 1 else False
for seq_group_meta in seq_group_metadata_list:
if seq_group_meta.is_prompt:
prefill_requests.append(seq_group_meta)
@@ -475,7 +503,7 @@ def prefetch_kv_caches(
matched = {}
for seq_group_meta in prefill_requests:
seq_id, seq_matched = self.prefetch_seq_kv_caches(
seq_group_meta, kv_caches, block_size,
seq_group_meta, kv_caches, block_size, is_comp_skippable,
)
matched[seq_id] = seq_matched
if matched:
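
To illustrate the slicing change from buffer[:, :, :matched] to buffer[:, :, offset:offset+matched]: the fetched cache starts at the chunk-aligned query_context_len, so when context_len is not aligned, its first offset = context_len % chunk_size entries duplicate tokens already present in the GPU KV cache. A minimal sketch with plain tensors, assuming illustrative shapes and names (not the PR's code):

```python
import torch

chunk_size = 4
context_len = 6                       # not aligned with chunk_size
offset = context_len % chunk_size     # 2 entries already present in the KV cache
matched = 5                           # matched tokens beyond context_len

# Pretend fetch buffer of shape (K/V, layers, tokens); the token axis starts at
# the chunk-aligned position query_context_len = 4.
fetch_buffer = torch.arange(2 * 1 * 12, dtype=torch.float32).reshape(2, 1, 12)

# Previous slicing: keeps the duplicated prefix and drops the last `offset`
# genuinely new tokens.
old = fetch_buffer[:, :, :matched]
# Fixed slicing: skips the duplicated prefix and takes exactly `matched` new tokens.
new = fetch_buffer[:, :, offset:offset + matched]

print(old[0, 0].tolist())   # [0.0, 1.0, 2.0, 3.0, 4.0]
print(new[0, 0].tolist())   # [2.0, 3.0, 4.0, 5.0, 6.0]
```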