Fix vineyard LLM cache #18
@@ -145,6 +145,15 @@ def __init__(
             f"max_num_batched_tokens ({self.max_num_batched_tokens}) must " \
             f"be a multiple of chunk_size ({self.chunk_size})"
         )
+
+        # Since we calculate the number of tokens as:
+        #   query_context_len = context_len - context_len % self.chunk_size
+        #   query_token_size = context_len + token_chunk_size - query_context_len
+        # where token_chunk_size can be as large as max_num_batched_tokens,
+        # the max number of tokens we must support is max_num_batched_tokens
+        # + chunk_size, for the case where context_len is not aligned with
+        # chunk_size.
+        self.max_num_batched_tokens += self.chunk_size
+
         self.fetch_buffer, self.fetch_tensors = self._pinned_tensor_creator()
         self.cuda_buffer = self.fetch_buffer.cuda()
         self.enable_async_update = enable_async_update
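The comment's worst case is easy to replay. Below is a minimal, self-contained sketch; the concrete values (`chunk_size = 16`, `max_num_batched_tokens = 512`) are assumed for illustration and are not taken from the PR.

```python
# Worst case for the fetch buffer: context_len is maximally unaligned
# with chunk_size and the scheduler hands over a full-sized chunk.
chunk_size = 16
max_num_batched_tokens = 512

context_len = 15                          # context_len % chunk_size == 15
token_chunk_size = max_num_batched_tokens

query_context_len = context_len - context_len % chunk_size   # -> 0
query_token_size = context_len + token_chunk_size - query_context_len

print(query_token_size)  # 527: exceeds max_num_batched_tokens (512)...
# ...but never the enlarged budget of max_num_batched_tokens + chunk_size.
assert query_token_size <= max_num_batched_tokens + chunk_size
```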
@@ -308,6 +317,7 @@ def prefetch_seq_kv_caches(
         seq_group_metadata: SequenceGroupMetadata,
         kv_caches: List[torch.Tensor],
         block_size: int,
+        is_comp_skippable: bool,
     ) -> Tuple[str, int]:
         from vllm._custom_ops import reshape_and_cache_flash
         if get_tensor_model_parallel_rank() == 0:
@@ -329,6 +339,16 @@ def prefetch_seq_kv_caches(
             # align `context_len` to `self.chunk_size`
             query_context_len = context_len - context_len % self.chunk_size
             query_token_size = context_len + token_chunk_size - query_context_len
+            # align `query_token_size` up to the next multiple of `self.chunk_size`.
+            # suppose `query_token_size` is 511 and `self.chunk_size` is 16: rather
+            # than using 496 to query, we use 512 in order to reduce the number of
+            # tokens to be recomputed.
+            query_token_size = (
+                (query_token_size + self.chunk_size - 1)
+                // self.chunk_size
+                * self.chunk_size
+            )
+            query_token_size = min(query_token_size, len(tokens) - query_context_len)
             query_prefix = tokens[:query_context_len]
             query_tokens = tokens[query_context_len:query_context_len + query_token_size]
             query_args = [
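The round-up is plain ceiling arithmetic. A small sketch replaying the 511/16 example from the new comment; the helper name `align_up` is mine, not the PR's.

```python
def align_up(n: int, chunk: int) -> int:
    """Round n up to the next multiple of chunk."""
    return (n + chunk - 1) // chunk * chunk

chunk_size = 16
# Query a full extra chunk (512) rather than using 496, per the comment.
print(align_up(511, chunk_size))              # 512
# The min() in the diff then caps the query at the tokens that actually
# remain past query_context_len, e.g. with 1000 remaining tokens:
print(min(align_up(511, chunk_size), 1000))   # 512
```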
@@ -373,18 +393,22 @@ def prefetch_seq_kv_caches(
             if self.metrics_enabled:
                 duration = time.perf_counter() - start_time
                 self.metrics.add_time_query(duration)
-            # shift
-            offset = context_len % self.chunk_size
-            matched -= offset
-            # If sampling is required, we need to leave one token unmatched
+            # If not comp skippable or sampling is required, we need to leave one token unmatched
             # to trigger the following sampling step in engine worker's workflow.
-            if seq_group_metadata is not None and seq_group_metadata.is_sampling_enabled:
+            if not is_comp_skippable or (
+                seq_group_metadata is not None and seq_group_metadata.is_sampling_enabled
+            ):
                 matched = min(matched, token_chunk_size - 1)
             else:
                 matched = min(matched, token_chunk_size)
             # synchronized across tensor parallel ranks
             matched_tensor = torch.tensor([matched], dtype=torch.long, device='cuda')
             tensor_model_parallel_all_reduce(input_=matched_tensor, op=torch.distributed.ReduceOp.MIN)
             matched = matched_tensor[0].item()

+            # shift
+            offset = context_len % self.chunk_size
+            matched -= offset
             if matched <= 0:
                 return seq_id, 0
             if get_tensor_model_parallel_rank() == 0:

Review comment on `if not is_comp_skippable or (`:

This is the workaround you mentioned? We can still explore whether it's possible to rebuild some seq group metadata later in this case, right? Or has it been fully evaluated, with no chance of a performance gain here?

Reply:

We can explore it later, but my gut feeling is that this wouldn't gain many benefits.
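The reordering matters: the clamp to `token_chunk_size` is now applied to the chunk-aligned `matched` count, and the alignment offset is subtracted only after the cross-rank reduction. A standalone sketch with assumed example numbers; the tensor-parallel all-reduce is elided, since it only takes the minimum of `matched` across ranks.

```python
def clamp_and_shift(matched: int, token_chunk_size: int, context_len: int,
                    chunk_size: int, is_comp_skippable: bool,
                    sampling_enabled: bool) -> int:
    # Leave one token unmatched unless this chunk's computation can be
    # skipped entirely and no sampling is needed for the step.
    if not is_comp_skippable or sampling_enabled:
        matched = min(matched, token_chunk_size - 1)
    else:
        matched = min(matched, token_chunk_size)
    # (in the real code, all TP ranks agree on min(matched) at this point)
    # shift: drop the tokens fetched only for chunk alignment
    offset = context_len % chunk_size
    return matched - offset

# Fully skippable chunk: the whole 512-token chunk may come from the cache.
print(clamp_and_shift(512, 512, 0, 16, True, False))   # 512
# Otherwise one token is left unmatched to trigger the sampling step.
print(clamp_and_shift(512, 512, 0, 16, False, False))  # 511
```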
@@ -413,7 +437,7 @@ def prefetch_seq_kv_caches(
             # efficient than performing multiple smaller copy operations. This
             # approach reduces the number of transfers between CPU and GPU,
             # leading to faster overall performance.
-            buffer = self.cuda_buffer.copy_(self.fetch_buffer)[:, :, :matched]
+            buffer = self.cuda_buffer.copy_(self.fetch_buffer)[:, :, offset:offset+matched]
             if self.metrics_enabled:
                 copy_end.record()
                 copy_end.synchronize()
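This is the companion to the shift change above: `matched` is no longer pre-shifted, so the usable window inside the fetched buffer starts at `offset` instead of 0. A CPU-only toy of the slicing; the real code slices `self.cuda_buffer` on the GPU, and the shapes here are made up.

```python
import torch

# Toy stand-in: dim 2 is the token axis, as in the sliced buffer above.
fetch_buffer = torch.arange(2 * 4 * 32, dtype=torch.float32).reshape(2, 4, 32)
cuda_buffer = torch.empty_like(fetch_buffer)  # on the GPU in the real code

offset, matched = 5, 16
# One bulk copy, then a view of the usable window: the first `offset`
# entries exist only because the fetch was widened for chunk alignment.
buffer = cuda_buffer.copy_(fetch_buffer)[:, :, offset:offset + matched]

assert buffer.shape[-1] == matched
assert torch.equal(buffer, fetch_buffer[:, :, offset:offset + matched])
```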
@@ -424,6 +448,7 @@ def prefetch_seq_kv_caches(
             reshape_start = torch.cuda.Event(enable_timing=True)
             reshape_end = torch.cuda.Event(enable_timing=True)
             reshape_start.record()
+
             for j in range(self.layer):
                 # use `reshape_and_cache_flash` rather than `copy_` as
                 # the target kv cache slots are not contiguous.
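`reshape_and_cache_flash` is vLLM's fused kernel for writing per-token KV entries into paged cache slots. As a rough pure-PyTorch illustration of why a plain `copy_` cannot be used here; the flat slot-pool layout below is a simplification, not vLLM's actual cache layout.

```python
import torch

# Simplified paged-KV picture: a flat pool of slots in which a sequence's
# blocks need not be adjacent, so the destinations are scattered.
kv_cache = torch.zeros(64, 8)             # 64 slots, head_dim 8
fetched = torch.randn(3, 8)               # 3 fetched token entries
slot_mapping = torch.tensor([40, 41, 7])  # non-contiguous target slots

# copy_() requires a single contiguous destination; scattering by index
# is what the fused kernel does per layer for the key and value tensors.
kv_cache.index_copy_(0, slot_mapping, fetched)
assert torch.equal(kv_cache[7], fetched[2])
```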
@@ -459,9 +484,12 @@ def prefetch_kv_caches(
         '''
         if block_size is None or kv_caches[0] is None:  # profile run
             return {}
+        # skippable only if the seq_group_metadata_list contains a single element
+        is_comp_skippable = True
         if get_tensor_model_parallel_rank() == 0:
             prefill_requests = []
             if seq_group_metadata_list is not None:
+                is_comp_skippable = True if len(seq_group_metadata_list) == 1 else False
                 for seq_group_meta in seq_group_metadata_list:
                     if seq_group_meta.is_prompt:
                         prefill_requests.append(seq_group_meta)
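How the flag flows from the batch to the per-sequence call can be sketched as a simplified driver; this stubs out the per-sequence fetch and omits the tensor-parallel rank handling, so it is an illustration rather than the PR's code.

```python
from typing import Callable, Dict, List, Optional, Tuple

def prefetch_all(
    seq_group_metadata_list: Optional[List],
    prefetch_seq: Callable[[object, bool], Tuple[str, int]],
) -> Dict[str, int]:
    # Per the diff: computation is skippable only when the batch holds
    # exactly one sequence group.
    is_comp_skippable = (
        seq_group_metadata_list is not None
        and len(seq_group_metadata_list) == 1
    )
    matched: Dict[str, int] = {}
    for meta in seq_group_metadata_list or []:
        seq_id, seq_matched = prefetch_seq(meta, is_comp_skippable)
        matched[seq_id] = seq_matched
    return matched

# Stubbed usage: a single-element batch gets the skippable fast path.
print(prefetch_all(["g0"], lambda m, skip: ("seq-0", 512 if skip else 511)))
```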
@@ -475,7 +503,7 @@ def prefetch_kv_caches(
             matched = {}
             for seq_group_meta in prefill_requests:
                 seq_id, seq_matched = self.prefetch_seq_kv_caches(
-                    seq_group_meta, kv_caches, block_size,
+                    seq_group_meta, kv_caches, block_size, is_comp_skippable,
                 )
                 matched[seq_id] = seq_matched
             if matched:
Review comment on the `query_token_size` alignment:

Is this just a performance improvement, or a bug fix? For the benefits: say we have 511 tokens and a chunk size of 16, so `query_context_len` = 496. What's the value of `query_token_size`?

Follow-up:

How can `query_token_size` be 511?

Reply:

If `query_context_len` = 496, then `query_token_size` = 16. With `context_len` = 0, the `query_token_size` calculated by `query_token_size = context_len + token_chunk_size - query_context_len` will be 511, since one slot is taken by decode and `token_chunk_size` = 511.
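The numbers in this exchange replay directly; the 512-token batch budget is implied by "one slot is taken by decode" and is otherwise an assumption.

```python
chunk_size = 16

def query_split(context_len: int, token_chunk_size: int):
    query_context_len = context_len - context_len % chunk_size
    query_token_size = context_len + token_chunk_size - query_context_len
    return query_context_len, query_token_size

# Fresh 511-token prefill: one slot of a 512-token budget goes to a
# decode, so context_len = 0 and token_chunk_size = 511.
print(query_split(0, 511))    # (0, 511)

# With this PR's round-up, the 511-token query is widened to 512:
print((511 + chunk_size - 1) // chunk_size * chunk_size)   # 512
```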