Skip to content

Commit

Permalink
feat: add support for chunked prefill + prefix caching
Browse files Browse the repository at this point in the history
  • Loading branch information
AlpinDale committed Dec 7, 2024
1 parent ef99a56 commit ace4bbe
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 26 deletions.
19 changes: 13 additions & 6 deletions aphrodite/processing/block_manager_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,14 +680,20 @@ def access_all_blocks_in_seq(
for block in block_table:
block.last_accessed = access_time

def compute_full_blocks_in_seq(self, seq: Sequence):
def compute_full_blocks_in_seq(self, seq: Sequence, token_chunk_size: int):
if seq.seq_id not in self.block_tables:
return
max_full_block = seq.get_len() // self.block_size - 1

# When chunked prefill is enabled, the computed full blocks
# should be calculated based on the number of computed tokens.
max_computed_tokens = (seq.data.get_num_computed_tokens() +
token_chunk_size)
computed_full_blocks = max_computed_tokens // self.block_size

block_table = self.block_tables[seq.seq_id]
if max_full_block == -1:
if computed_full_blocks == 0:
return
for i in reversed(range(max_full_block)):
for i in reversed(range(computed_full_blocks)):
if block_table[i].computed:
break
block_table[i].computed = True
Expand Down Expand Up @@ -717,10 +723,11 @@ def get_common_computed_block_ids(
ids_list = [self.get_all_computed_blocks(seq) for seq in seqs]
return commonprefix([ids for ids in ids_list if ids != []])

def mark_blocks_as_computed(self, seq_group: SequenceGroup):
def mark_blocks_as_computed(self, seq_group: SequenceGroup,
token_chunk_size: int):
if self.enable_caching:
for seq in seq_group.get_seqs():
self.compute_full_blocks_in_seq(seq)
self.compute_full_blocks_in_seq(seq, token_chunk_size)

def get_prefix_cache_hit_rate(self, device: Device) -> float:
if device == Device.GPU:
Expand Down
3 changes: 2 additions & 1 deletion aphrodite/processing/block_manager_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,8 @@ def access_all_blocks_in_seq(self, seq: Sequence, now: float):
self._last_access_blocks_tracker.update_last_access(
seq.seq_id, now)

def mark_blocks_as_computed(self, seq_group: SequenceGroup):
def mark_blocks_as_computed(self, seq_group: SequenceGroup,
token_chunk_size: int):
# Marking blocks as computed is only needed for prefix caching,
# and currently we can determine whether a block is computed
# by checking whether it has a content hash.
Expand Down
3 changes: 2 additions & 1 deletion aphrodite/processing/placeholder_block_space_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@ def get_common_computed_block_ids(self,
seq_group: SequenceGroup) -> List[int]:
return None # type: ignore

def mark_blocks_as_computed(self, seq_group: SequenceGroup):
def mark_blocks_as_computed(self, seq_group: SequenceGroup,
token_chunk_size: int):
pass

def get_prefix_cache_hit_rate(self, device: Device) -> float:
Expand Down
30 changes: 24 additions & 6 deletions aphrodite/processing/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1152,7 +1152,8 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
# will crash the Aphrodite instance / will not retry.
for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups:
self.block_manager.mark_blocks_as_computed(
scheduled_seq_group.seq_group)
scheduled_seq_group.seq_group,
scheduled_seq_group.token_chunk_size)

return seq_group_metadata_list, scheduler_outputs

Expand Down Expand Up @@ -1344,10 +1345,27 @@ def _get_num_new_tokens(self, seq_group: SequenceGroup,
for seq in seqs:
num_new_tokens += seq.get_num_new_tokens()
assert num_new_tokens > 0
# Chunk if a running request cannot fit in.
# If number of seq > 1, it means it is doing beam search in a
# decode phase. Do not chunk in that case.
# Chunk if a running request cannot fit in the given budget.
# If number of seq > 1, it means it is doing beam search
# in a decode phase. Do not chunk.
if enable_chunking and len(seqs) == 1:
num_new_tokens = min(num_new_tokens,
budget.remaining_token_budget())
remaining_token_budget = budget.remaining_token_budget()
if self.cache_config.enable_prefix_caching:
# When prefix caching is enabled, we always allocate
# a number of new tokens that is divisible by the block size
# to avoid partial block matching.
block_size = self.cache_config.block_size
reminder = budget.token_budget % block_size
if reminder != 0:
raise ValueError("When enabling chunked prefill and "
"prefix caching, max_num_batched_tokens "
"(chunk size) must be dividable by "
"block size, but got chunk_size "
f"({budget.token_budget}) % block_size "
f"({block_size}) = {reminder}")
if remaining_token_budget < num_new_tokens:
num_new_tokens = (remaining_token_budget //
block_size) * block_size
else:
num_new_tokens = min(num_new_tokens, remaining_token_budget)
return num_new_tokens
49 changes: 37 additions & 12 deletions aphrodite/task_handler/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,23 +499,48 @@ def _compute_for_prefix_cache_hit(
and self.sliding_window is None
and inter_data.is_prompt)
inter_data.prefix_cache_hit = prefix_cache_hit
if self.chunked_prefill_enabled and prefix_cache_hit:
raise RuntimeError(
"chunked prefill cannot be used with prefix caching now.")

# If prefix cache is hit, advance context length to bypass
# hit blocks. Accordingly, input tokens, position and query length
# have to be updated.
if prefix_cache_hit:
assert computed_block_nums is not None
context_len = len(computed_block_nums) * self.block_size

if not prefix_cache_hit:
return

assert computed_block_nums is not None
# The number of cache-hit prompt tokens in this sequence. Note
# that this may be larger than the sequence length if chunked
# prefill is enabled.
prefix_cache_len = len(computed_block_nums) * self.block_size
# The number of so far computed prompt tokens in this sequence.
context_len = inter_data.context_lens[seq_idx]
# The total number of prompt tokens in this sequence.
# When chunked prefill is enabled, this is the number of tokens
# in the already-computed chunks plus the current chunk.
seq_len = inter_data.seq_lens[seq_idx]
if prefix_cache_len <= context_len:
# We already passed the cache hit region,
# so do normal computation.
pass
elif context_len < prefix_cache_len < seq_len:
# Partial hit. Compute the missing part.
uncomputed_start = prefix_cache_len - context_len
inter_data.input_tokens[seq_idx] = inter_data.input_tokens[
seq_idx][context_len:]
seq_idx][uncomputed_start:]
inter_data.input_positions[seq_idx] = inter_data.input_positions[
seq_idx][context_len:]
seq_idx][uncomputed_start:]
context_len = prefix_cache_len

inter_data.context_lens[seq_idx] = context_len
inter_data.query_lens[
seq_idx] = inter_data.seq_lens[seq_idx] - context_len
elif seq_len <= prefix_cache_len:
# Full hit. Only compute the last token to avoid
# erroneous behavior. FIXME: Ideally we should directly
# mark all tokens as computed in the scheduler and not
# schedule this sequence, so this case should not happen.
inter_data.input_tokens[seq_idx] = inter_data.input_tokens[
seq_idx][-1:]
inter_data.input_positions[seq_idx] = inter_data.input_positions[
seq_idx][-1:]
inter_data.query_lens[seq_idx] = 1
inter_data.context_lens[seq_idx] = inter_data.seq_lens[seq_idx] - 1

def _compute_for_sliding_window(self, inter_data: InterDataForSeqGroup,
seq_idx: int,
Expand Down

0 comments on commit ace4bbe

Please sign in to comment.