Revert "feat: add support for chunked prefill + prefix caching (#871)" #903

Open
wants to merge 1 commit into base: main
19 changes: 6 additions & 13 deletions aphrodite/processing/block_manager_v1.py
@@ -680,20 +680,14 @@ def access_all_blocks_in_seq(
for block in block_table:
block.last_accessed = access_time

def compute_full_blocks_in_seq(self, seq: Sequence, token_chunk_size: int):
def compute_full_blocks_in_seq(self, seq: Sequence):
if seq.seq_id not in self.block_tables:
return

# When chunked prefill is enabled, the computed full blocks
# should be calculated based on the number of computed tokens.
max_computed_tokens = (seq.data.get_num_computed_tokens() +
token_chunk_size)
computed_full_blocks = max_computed_tokens // self.block_size

max_full_block = seq.get_len() // self.block_size - 1
block_table = self.block_tables[seq.seq_id]
if computed_full_blocks == 0:
if max_full_block == -1:
return
for i in reversed(range(computed_full_blocks)):
for i in reversed(range(max_full_block)):
if block_table[i].computed:
break
block_table[i].computed = True
@@ -723,11 +717,10 @@ def get_common_computed_block_ids(
ids_list = [self.get_all_computed_blocks(seq) for seq in seqs]
return commonprefix([ids for ids in ids_list if ids != []])

def mark_blocks_as_computed(self, seq_group: SequenceGroup,
token_chunk_size: int):
def mark_blocks_as_computed(self, seq_group: SequenceGroup):
if self.enable_caching:
for seq in seq_group.get_seqs():
self.compute_full_blocks_in_seq(seq, token_chunk_size)
self.compute_full_blocks_in_seq(seq)

def get_prefix_cache_hit_rate(self, device: Device) -> float:
if device == Device.GPU:
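
For orientation, here is a minimal standalone sketch (not part of the diff; the `Block` record and argument names are illustrative) of the two marking strategies this file switches between: the chunk-aware version marks only the blocks covered by already-computed tokens plus the current prefill chunk, while the restored version marks every full block implied by the sequence length.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Block:
    computed: bool = False


def mark_chunk_aware(block_table: List[Block], block_size: int,
                     num_computed_tokens: int, token_chunk_size: int) -> None:
    # Chunk-aware variant (the one being reverted): only blocks fully covered
    # by the tokens computed so far plus the current prefill chunk are marked.
    computed_full_blocks = (num_computed_tokens + token_chunk_size) // block_size
    for i in reversed(range(computed_full_blocks)):
        if block_table[i].computed:
            break
        block_table[i].computed = True


def mark_by_seq_len(block_table: List[Block], block_size: int,
                    seq_len: int) -> None:
    # Variant being restored: mark every full block implied by the sequence
    # length (except the last), regardless of how many tokens were computed.
    max_full_block = seq_len // block_size - 1
    if max_full_block == -1:
        return
    for i in reversed(range(max_full_block)):
        if block_table[i].computed:
            break
        block_table[i].computed = True


# block_size=4, a 10-token prompt prefilled in a 6-token first chunk:
table = [Block() for _ in range(3)]
mark_chunk_aware(table, block_size=4, num_computed_tokens=0, token_chunk_size=6)
assert [b.computed for b in table] == [True, False, False]
```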
3 changes: 1 addition & 2 deletions aphrodite/processing/block_manager_v2.py
@@ -284,8 +284,7 @@ def access_all_blocks_in_seq(self, seq: Sequence, now: float):
self._last_access_blocks_tracker.update_last_access(
seq.seq_id, now)

def mark_blocks_as_computed(self, seq_group: SequenceGroup,
token_chunk_size: int):
def mark_blocks_as_computed(self, seq_group: SequenceGroup):
# The only need for mark block as computed is for prefix caching,
# while currently we could determine whether one block is computed
# or not by check whether it has content hash.
3 changes: 1 addition & 2 deletions aphrodite/processing/placeholder_block_space_manager.py
@@ -80,8 +80,7 @@ def get_common_computed_block_ids(self,
seq_group: SequenceGroup) -> List[int]:
return None # type: ignore

def mark_blocks_as_computed(self, seq_group: SequenceGroup,
token_chunk_size: int):
def mark_blocks_as_computed(self, seq_group: SequenceGroup):
pass

def get_prefix_cache_hit_rate(self, device: Device) -> float:
30 changes: 6 additions & 24 deletions aphrodite/processing/scheduler.py
@@ -1152,8 +1152,7 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
# will crash the Aphrodite instance / will not retry.
for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups:
self.block_manager.mark_blocks_as_computed(
scheduled_seq_group.seq_group,
scheduled_seq_group.token_chunk_size)
scheduled_seq_group.seq_group)

return seq_group_metadata_list, scheduler_outputs

@@ -1345,27 +1344,10 @@ def _get_num_new_tokens(self, seq_group: SequenceGroup,
for seq in seqs:
num_new_tokens += seq.get_num_new_tokens()
assert num_new_tokens > 0
# Chunk if a running request cannot fit in the given budget.
# If number of seq > 1, it means it is doing beam search
# in a decode phase. Do not chunk.
# Chunk if a running request cannot fit in.
# If number of seq > 1, it means it is doing beam search in a
# decode phase. Do not chunk in that case.
if enable_chunking and len(seqs) == 1:
remaining_token_budget = budget.remaining_token_budget()
if self.cache_config.enable_prefix_caching:
# When prefix caching is enabled, we always allocate
# the number of new tokens that is dividable by the block size
# to avoid partial block matching.
block_size = self.cache_config.block_size
reminder = budget.token_budget % block_size
if reminder != 0:
raise ValueError("When enabling chunked prefill and "
"prefix caching, max_num_batched_tokens "
"(chunk size) must be dividable by "
"block size, but got chunk_size "
f"({budget.token_budget}) % block_size "
f"({block_size}) = {reminder}")
if remaining_token_budget < num_new_tokens:
num_new_tokens = (remaining_token_budget //
block_size) * block_size
else:
num_new_tokens = min(num_new_tokens, remaining_token_budget)
num_new_tokens = min(num_new_tokens,
budget.remaining_token_budget())
return num_new_tokens
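
To make the removed chunking rule concrete, here is a small sketch (illustrative helper names, not the scheduler's API): with prefix caching enabled, the chunk was rounded down to a block boundary so a prefill chunk never ended mid-block, whereas the restored code simply caps the chunk at the remaining token budget.

```python
def chunk_with_prefix_caching(num_new_tokens: int, remaining_budget: int,
                              block_size: int) -> int:
    # Removed behavior: round the chunk down to a block boundary so no
    # partially filled block is left behind for prefix matching.
    if remaining_budget < num_new_tokens:
        return (remaining_budget // block_size) * block_size
    return num_new_tokens


def chunk_simple(num_new_tokens: int, remaining_budget: int) -> int:
    # Restored behavior: just cap the chunk at the remaining budget.
    return min(num_new_tokens, remaining_budget)


# Same numbers as the removed scheduler test: 14 tokens of budget remain and
# block_size is 4, so only 12 tokens get scheduled instead of 14.
assert chunk_with_prefix_caching(50, 14, block_size=4) == 12
assert chunk_simple(50, 14) == 14
```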
49 changes: 12 additions & 37 deletions aphrodite/task_handler/model_runner.py
@@ -501,48 +501,23 @@ def _compute_for_prefix_cache_hit(
and self.sliding_window is None
and inter_data.is_prompt)
inter_data.prefix_cache_hit = prefix_cache_hit

if not prefix_cache_hit:
return

assert computed_block_nums is not None
# The cache hit prompt tokens in this sequence. Note that
# this may be larger than the sequence length if chunked
# prefill is enabled.
prefix_cache_len = len(computed_block_nums) * self.block_size
# The number of so far computed prompt tokens in this sequence.
context_len = inter_data.context_lens[seq_idx]
# The total number of prompt tokens in this sequence.
# When chunked prefill is enabled, this is the token number of
# computed chunks + current chunk.
seq_len = inter_data.seq_lens[seq_idx]
if prefix_cache_len <= context_len:
# We already passed the cache hit region,
# so do normal computation.
pass
elif context_len < prefix_cache_len < seq_len:
# Partial hit. Compute the missing part.
uncomputed_start = prefix_cache_len - context_len
if self.chunked_prefill_enabled and prefix_cache_hit:
raise RuntimeError(
"chunked prefill cannot be used with prefix caching now.")

# If prefix cache is hit, advance context length to bypass
# hit blocks. Accordingly, input tokens, position and query length
# have to be updated.
if prefix_cache_hit:
assert computed_block_nums is not None
context_len = len(computed_block_nums) * self.block_size
inter_data.input_tokens[seq_idx] = inter_data.input_tokens[
seq_idx][uncomputed_start:]
seq_idx][context_len:]
inter_data.input_positions[seq_idx] = inter_data.input_positions[
seq_idx][uncomputed_start:]
context_len = prefix_cache_len

seq_idx][context_len:]
inter_data.context_lens[seq_idx] = context_len
inter_data.query_lens[
seq_idx] = inter_data.seq_lens[seq_idx] - context_len
elif seq_len <= prefix_cache_len:
# Full hit. Only compute the last token to avoid
# erroneous behavior. FIXME: Ideally we should directly
# mark all tokens as computed in the scheduler and do not
# schedule this sequence, so this case should not happen.
inter_data.input_tokens[seq_idx] = inter_data.input_tokens[
seq_idx][-1:]
inter_data.input_positions[seq_idx] = inter_data.input_positions[
seq_idx][-1:]
inter_data.query_lens[seq_idx] = 1
inter_data.context_lens[seq_idx] = inter_data.seq_lens[seq_idx] - 1

def _compute_for_sliding_window(self, inter_data: InterDataForSeqGroup,
seq_idx: int,
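
The removed branching for partial and full prefix-cache hits can be summarized with a self-contained sketch (illustrative functions, not the runner's actual API); the restored path always advances the context past exactly the cached blocks, which is why it now refuses to run together with chunked prefill.

```python
from typing import List, Tuple


def trim_chunk_aware(input_tokens: List[int], context_len: int, seq_len: int,
                     prefix_cache_len: int) -> Tuple[List[int], int, int]:
    # Removed behavior: returns (tokens_to_compute, new_context_len, query_len).
    if prefix_cache_len <= context_len:
        # The cached region is already behind us; compute normally.
        return input_tokens, context_len, seq_len - context_len
    if context_len < prefix_cache_len < seq_len:
        # Partial hit: skip the cached portion of this chunk.
        uncomputed_start = prefix_cache_len - context_len
        return (input_tokens[uncomputed_start:], prefix_cache_len,
                seq_len - prefix_cache_len)
    # Full hit: recompute only the last token to avoid an empty query.
    return input_tokens[-1:], seq_len - 1, 1


def trim_reverted(input_tokens: List[int], seq_len: int,
                  num_computed_blocks: int, block_size: int
                  ) -> Tuple[List[int], int, int]:
    # Restored behavior: the context length jumps past the cached blocks.
    context_len = num_computed_blocks * block_size
    return input_tokens[context_len:], context_len, seq_len - context_len


# Partial hit: 8 of 12 prompt tokens are cached and none are computed yet.
tokens = list(range(12))
assert trim_chunk_aware(tokens, context_len=0, seq_len=12,
                        prefix_cache_len=8) == (tokens[8:], 8, 4)
```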
66 changes: 0 additions & 66 deletions tests/basic_correctness/test_chunked_prefill.py
@@ -6,7 +6,6 @@

Run `pytest tests/models/test_chunked_prefill.py`.
"""
from contextlib import nullcontext

import pytest

@@ -152,68 +151,3 @@ def test_models_with_fp8_kv_cache(
name_0="no_chunked_prefill",
name_1="chunked_prefill",
)


@pytest.mark.parametrize("max_tokens", [16])
@pytest.mark.parametrize("enforce_eager", [False])
@pytest.mark.parametrize("chunk_size", [30, 32])
@pytest.mark.parametrize("use_v2_block_manager", [False, True])
# NOTE: Increasing this in this suite will fail CI because we currently cannot
# reset distributed env properly. Use a value > 1 just when you test.
@pytest.mark.parametrize("tensor_parallel_size", [1])
def test_with_prefix_caching(
aphrodite_runner,
max_tokens: int,
enforce_eager: bool,
chunk_size: int,
use_v2_block_manager: bool,
tensor_parallel_size: int,
) -> None:
"""
Checks exact match decode with and without prefix caching
with chunked prefill enabled.
"""
model = "meta-llama/Llama-2-7b-chat-hf"
# The common prompt has 142 tokens with Llama-2 tokenizer.
common_prompt = "You are a helpful AI assistant " * 20
unique_prompts = [
"Question", # Warmup
"Question", # Fully cached
"Another question", # Partial cached
]
full_prompts = [f"{common_prompt}\n{p}" for p in unique_prompts]

max_num_batched_tokens = max_num_seqs = chunk_size
outputs = {} # type: ignore
check_result = True
for enable in (True, False):
with aphrodite_runner(
model,
dtype="half",
max_num_batched_tokens=max_num_batched_tokens,
enable_chunked_prefill=True,
enable_prefix_caching=enable,
tensor_parallel_size=tensor_parallel_size,
use_v2_block_manager=use_v2_block_manager,
enforce_eager=enforce_eager,
max_num_seqs=max_num_seqs,
) as aphrodite_model:
# It should fail when prefix caching is enable and chunk
# size is not a multiple of block size (16).
should_fail = chunk_size % 16 != 0 and enable
check_result &= not should_fail
outputs[enable] = []
# Send the request one-by-one to ensure the cache is populated.
with pytest.raises(ValueError) if should_fail else nullcontext():
for prompt in full_prompts:
outputs[enable] += aphrodite_model.generate_greedy(
[prompt], max_tokens)

# Check results only if we did not expect a failure.
if check_result:
check_outputs_equal(
outputs_0_lst=outputs[False],
outputs_1_lst=outputs[True],
name_0="w/o prefix caching",
name_1="with prefix caching",
)
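
As a side note on the removed test's parameters (an observation, not part of the diff): with the block size of 16 noted in the test, chunk_size=30 is expected to raise the divisibility error once prefix caching is enabled (30 % 16 == 14), while chunk_size=32 is accepted (32 % 16 == 0). A minimal sketch of that check:

```python
def validate_chunk_size(chunk_size: int, block_size: int = 16) -> None:
    # Mirrors the scheduler check being removed above: when prefix caching is
    # enabled, the chunk size must be a multiple of the block size.
    remainder = chunk_size % block_size
    if remainder != 0:
        raise ValueError(
            f"chunk_size ({chunk_size}) % block_size ({block_size}) "
            f"= {remainder}; must be 0 when prefix caching is enabled")


validate_chunk_size(32)  # accepted
# validate_chunk_size(30) would raise ValueError, as the removed test expects.
```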
40 changes: 0 additions & 40 deletions tests/core/test_block_manager.py
@@ -596,43 +596,3 @@ def test_sliding_window_multi_seq():

# assert all blocks are free now
assert block_manager.get_num_free_gpu_blocks() == num_gpu_blocks


def test_mark_blocks_as_computed_with_prefix_cache_and_chunked_prefill():
"""When prefix cache and chunked prefill are enabled, the block manager
should only mark a chunk of blocks as computed instead of all blocks.
"""

block_size = 4
num_cpu_blocks = 0
num_gpu_blocks = 16
block_manager = BlockSpaceManagerV1(block_size,
num_gpu_blocks,
num_cpu_blocks,
watermark=0,
enable_caching=True)

# Set prompt size to have num_gpu_blocks - 1 full blocks.
prompt_length = block_size * num_gpu_blocks - 1

# Allocate (reserve) all blocks.
_, seq_group = create_dummy_prompt("0",
prompt_length,
block_size=block_size)
block_manager.allocate(seq_group)
assert seq_group.seqs[0].n_blocks == num_gpu_blocks

# 1st chunk: Compute 2 and half blocks. Should mark 2 blocks as computed.
token_chunk_size = int(block_size * 2.5)
block_manager.mark_blocks_as_computed(seq_group, token_chunk_size)
computed_blocks = block_manager.get_all_computed_blocks(seq_group.seqs[0])
assert len(computed_blocks) == 2

# Actual computed tokens.
seq_group.seqs[0].data.update_num_computed_tokens(token_chunk_size)

# 2nd chunk: Complete 3rd block and additional 4 blocks.
token_chunk_size = int(block_size * 4.5)
block_manager.mark_blocks_as_computed(seq_group, token_chunk_size)
computed_blocks = block_manager.get_all_computed_blocks(seq_group.seqs[0])
assert len(computed_blocks) == 7
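
A quick check of the removed test's arithmetic (illustrative only): with block_size = 4, the first 2.5-block chunk covers 10 tokens and therefore 10 // 4 = 2 full blocks; after the second 4.5-block chunk, the 28 computed tokens cover 28 // 4 = 7 full blocks, matching the two assertions above.

```python
block_size = 4
first_chunk = int(block_size * 2.5)     # 10 tokens -> 10 // 4 = 2 full blocks
second_chunk = int(block_size * 4.5)    # 18 tokens -> 28 // 4 = 7 full blocks
assert first_chunk // block_size == 2
assert (first_chunk + second_chunk) // block_size == 7
```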
39 changes: 0 additions & 39 deletions tests/core/test_chunked_prefill_scheduler.py
@@ -581,42 +581,3 @@ def test_chunked_prefill_max_seqs():
assert len(get_sequence_groups(out)) == max_seqs
assert not running[0].is_prefill()
assert not running[1].is_prefill()


def test_perfix_caching():
"""Verify allocating full blocks when prefix caching is enabled."""
block_size = 4
max_seqs = 10
max_model_len = 80
max_num_batched_tokens = 64
scheduler_config = SchedulerConfig(max_num_batched_tokens,
max_seqs,
max_model_len,
enable_chunked_prefill=True)
cache_config = CacheConfig(block_size,
1.0,
1,
"auto",
enable_prefix_caching=True)
cache_config.num_cpu_blocks = 0
cache_config.num_gpu_blocks = 32
scheduler = Scheduler(scheduler_config, cache_config, None)
running: List[SequenceGroup] = []

# Add seq groups to scheduler.
for i in range(2):
_, seq_group = create_dummy_prompt(str(i),
block_size=block_size,
prompt_length=50)
scheduler.add_seq_group(seq_group)
running.append(seq_group)

seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
assert set(get_sequence_groups(out)) == set(running)
assert seq_group_meta[0].token_chunk_size == 50
# Verify it is chunked. Note that although the budget is 64-50=14,
# we only allocate full blocks for prefix caching, so only 4*(14//4)=12
# tokens are allocated.
assert seq_group_meta[1].token_chunk_size == 12
assert out.num_prefill_groups == 2
assert out.num_batched_tokens == 62