support prefix cache
Signed-off-by: jiang1.li <[email protected]>
bigPYJ1151 committed Nov 15, 2024
1 parent 5980981 commit c45b967
Showing 2 changed files with 30 additions and 4 deletions.
8 changes: 5 additions & 3 deletions vllm/executor/cpu_executor.py
@@ -66,10 +66,12 @@ def _init_executor(self) -> None:
         self.parallel_config = _verify_and_get_parallel_config(
             self.parallel_config)
 
-        if (self.scheduler_config.chunked_prefill_enabled
+        if ((self.scheduler_config.chunked_prefill_enabled
+                or self.cache_config.enable_prefix_caching)
                 and self.model_config.dtype == torch.half):
-            logger.warning("Chunked-prefill on the CPU backend only does not"
-                           " support fp16 for now, cast to bf16.")
+            logger.warning("chunked-prefill and prefix-cache on the CPU "
+                           "backend does not support fp16 for now,"
+                           " cast to bf16.")
             self.model_config.dtype = torch.bfloat16
 
         # Multiprocessing-based executor does not support multi-node setting.
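The executor-side change is a dtype guard: when chunked prefill or prefix caching is enabled on the CPU backend and the model is configured for fp16, the dtype is promoted to bf16 with a warning. A minimal standalone sketch of that guard follows; the _Configs dataclass and maybe_promote_cpu_dtype name are hypothetical stand-ins for vLLM's ModelConfig/SchedulerConfig/CacheConfig objects, not the actual API.

import logging
from dataclasses import dataclass

import torch

logger = logging.getLogger(__name__)


@dataclass
class _Configs:
    """Hypothetical stand-in for vLLM's scheduler/cache/model configs."""
    chunked_prefill_enabled: bool
    enable_prefix_caching: bool
    dtype: torch.dtype


def maybe_promote_cpu_dtype(cfg: _Configs) -> torch.dtype:
    """Promote fp16 to bf16 when chunked prefill or prefix caching is on,
    mirroring the guard added to CPUExecutor._init_executor."""
    if ((cfg.chunked_prefill_enabled or cfg.enable_prefix_caching)
            and cfg.dtype == torch.half):
        logger.warning("chunked-prefill and prefix-cache on the CPU backend "
                       "do not support fp16 for now, casting to bf16.")
        cfg.dtype = torch.bfloat16
    return cfg.dtype


# Example: prefix caching enabled with an fp16 model falls back to bf16.
assert maybe_promote_cpu_dtype(
    _Configs(False, True, torch.half)) == torch.bfloat16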
26 changes: 25 additions & 1 deletion vllm/worker/cpu_model_runner.py
@@ -134,7 +134,8 @@ def __init__(self,
         super().__init__()
         self.seq_group_metadata_list: List[SequenceGroupMetadata] = []
         self.runner = runner
-        self.chunked_prefill = runner.scheduler_config.chunked_prefill_enabled
+        self.chunked_prefill = (runner.scheduler_config.chunked_prefill_enabled
+                                or runner.cache_config.enable_prefix_caching)
         self.model_input_cls = self.runner._model_input_cls
         self.attn_backend = self.runner.attn_backend
         self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper
@@ -284,6 +285,29 @@ def _compute_input_tokens(self, data: ModelInputData,
         context_len = seq_data.get_num_computed_tokens()
         if is_prompt:
             seq_len = min(seq_len, context_len + token_chunk_size)
+
+            # For prefix caching
+            prefix_cache_block_num = len(
+                seq_group_metadata.computed_block_nums)
+            if prefix_cache_block_num > 0:
+                prefix_cache_len = (prefix_cache_block_num *
+                                    self.runner.block_size)
+                if prefix_cache_len <= context_len:
+                    # We already passed the cache hit region,
+                    # so do normal computation.
+                    pass
+                elif context_len < prefix_cache_len < seq_len:
+                    # Partial hit. Compute the missing part.
+                    context_len = prefix_cache_len
+                    token_chunk_size = seq_len - context_len
+                elif seq_len <= prefix_cache_len:
+                    # Full hit. Only compute the last token to avoid
+                    # erroneous behavior. FIXME: Ideally we should directly
+                    # mark all tokens as computed in the scheduler and do not
+                    # schedule this sequence, so this case should not happen.
+                    context_len = seq_len - 1
+                    token_chunk_size = 1
+
         tokens = seq_data.get_token_ids()
         tokens = tokens[context_len:seq_len]
         token_positions = range(context_len, seq_len)
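The runner-side change treats prefix caching like chunked prefill (hence the widened chunked_prefill flag above) and then classifies each prompt chunk against the prefix-cache hit region in _compute_input_tokens: if the cached prefix ends at or before the already-computed context, nothing changes; if it ends inside the chunk, the chunk shrinks to the uncached tail; and if the cache covers the whole chunk, only the last token is recomputed. The self-contained sketch below reproduces that logic outside the runner so the three cases can be sanity-checked; the function name adjust_for_prefix_cache and the plain-integer arguments standing in for vLLM's sequence metadata are hypothetical.

from typing import Tuple


def adjust_for_prefix_cache(context_len: int,
                            seq_len: int,
                            token_chunk_size: int,
                            prefix_cache_block_num: int,
                            block_size: int) -> Tuple[int, int]:
    """Return (context_len, token_chunk_size) after accounting for a
    prefix-cache hit, mirroring the block added to _compute_input_tokens.

    context_len: tokens already computed for this sequence.
    seq_len: end (exclusive) of the token range scheduled this step.
    token_chunk_size: number of tokens scheduled this step.
    prefix_cache_block_num: len(seq_group_metadata.computed_block_nums).
    block_size: KV-cache block size of the runner.
    """
    if prefix_cache_block_num > 0:
        prefix_cache_len = prefix_cache_block_num * block_size
        if prefix_cache_len <= context_len:
            # Already past the cache-hit region: normal computation.
            pass
        elif context_len < prefix_cache_len < seq_len:
            # Partial hit: skip the cached prefix, compute only the rest.
            context_len = prefix_cache_len
            token_chunk_size = seq_len - context_len
        elif seq_len <= prefix_cache_len:
            # Full hit: recompute only the last token so the step still
            # produces an output (see the FIXME in the diff above).
            context_len = seq_len - 1
            token_chunk_size = 1
    return context_len, token_chunk_size


# Partial hit: block_size=16, 2 cached blocks (32 tokens), chunk [0, 48).
# The first 32 tokens are skipped and only the last 16 are computed.
assert adjust_for_prefix_cache(0, 48, 48, 2, 16) == (32, 16)
# Full hit: chunk [0, 24) entirely covered by the 2 cached blocks.
assert adjust_for_prefix_cache(0, 24, 24, 2, 16) == (23, 1)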
