Skip to content

Commit

Permalink
Add Automatic Prefix Caching (vllm-project#2762)
Browse files Browse the repository at this point in the history
Co-authored-by: ElizaWszola <[email protected]>
Co-authored-by: Michael Goin <[email protected]>
  • Loading branch information
3 people authored Mar 2, 2024
1 parent baee28c commit ce4f5a2
Show file tree
Hide file tree
Showing 18 changed files with 618 additions and 292 deletions.
30 changes: 16 additions & 14 deletions benchmarks/benchmark_throughput.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,21 +73,21 @@ def run_vllm(
enforce_eager: bool,
kv_cache_dtype: str,
device: str,
enable_prefix_caching: bool,
) -> float:
from vllm import LLM, SamplingParams
llm = LLM(
model=model,
tokenizer=tokenizer,
quantization=quantization,
tensor_parallel_size=tensor_parallel_size,
seed=seed,
trust_remote_code=trust_remote_code,
dtype=dtype,
max_model_len=max_model_len,
enforce_eager=enforce_eager,
kv_cache_dtype=kv_cache_dtype,
device=device,
)
llm = LLM(model=model,
tokenizer=tokenizer,
quantization=quantization,
tensor_parallel_size=tensor_parallel_size,
seed=seed,
trust_remote_code=trust_remote_code,
dtype=dtype,
max_model_len=max_model_len,
enforce_eager=enforce_eager,
kv_cache_dtype=kv_cache_dtype,
device=device,
enable_prefix_caching=enable_prefix_caching)

# Add the requests to the engine.
for prompt, _, output_len in requests:
Expand Down Expand Up @@ -211,7 +211,8 @@ def main(args: argparse.Namespace):
args.seed, args.n, args.use_beam_search,
args.trust_remote_code, args.dtype,
args.max_model_len, args.enforce_eager,
args.kv_cache_dtype, args.device)
args.kv_cache_dtype, args.device,
args.enable_prefix_caching)
elif args.backend == "hf":
assert args.tensor_parallel_size == 1
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
Expand Down Expand Up @@ -302,6 +303,7 @@ def main(args: argparse.Namespace):
default="cuda",
choices=["cuda"],
help='device type for vLLM execution, supporting CUDA only currently.')
parser.add_argument("--enable_prefix_caching", action='store_true')
args = parser.parse_args()
if args.tokenizer is None:
args.tokenizer = args.model
Expand Down
4 changes: 4 additions & 0 deletions docs/source/models/engine_args.rst
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,10 @@ Below, you can find an explanation of every engine argument for vLLM:

Token block size for contiguous chunks of tokens.

.. option:: --enable-prefix-caching

Enables automatic prefix caching

.. option:: --seed <seed>

Random seed for operations.
Expand Down
11 changes: 2 additions & 9 deletions examples/offline_inference_with_prefix.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,20 +37,13 @@

print("-" * 80)

# -1 since the last token can change when concatenating prompts.
prefix_pos = len(llm.llm_engine.tokenizer.encode(prefix)) - 1

# The llm.generate call will batch all prompts and send the batch at once if resources allow.
# The prefix will only be cached after the first batch is processed, so we need to call generate once
# to calculate the prefix and cache it.
outputs = llm.generate(generating_prompts[0],
sampling_params,
prefix_pos=[prefix_pos])
outputs = llm.generate(generating_prompts[0], sampling_params)

# Subsequent batches can leverage the cached prefix
outputs = llm.generate(generating_prompts,
sampling_params,
prefix_pos=[prefix_pos] * len(generating_prompts))
outputs = llm.generate(generating_prompts, sampling_params)

# Print the outputs. You should see the same outputs as before
for output in outputs:
Expand Down
103 changes: 69 additions & 34 deletions tests/prefix_caching/test_prefix_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,38 +4,73 @@
"""
import pytest

from vllm import LLM, SamplingParams

prefix = (
"You are an expert school principal, skilled in effectively managing "
"faculty and staff. Draft 10-15 questions for a potential first grade "
"Head Teacher for my K-12, all-girls', independent school that emphasizes "
"community, joyful discovery, and life-long learning. The candidate is "
"coming in for a first-round panel interview for a 8th grade Math "
"teaching role. They have 5 years of previous teaching experience "
"as an assistant teacher at a co-ed, public school with experience "
"in middle school math teaching. Based on these information, fulfill "
"the following paragraph: ")


@pytest.mark.parametrize("model", ["facebook/opt-125m"])
@pytest.mark.parametrize("max_tokens", [16])
def test_prefix_caching(
example_prompts,
model: str,
max_tokens: int,
from vllm.core.block_manager import BlockAllocator
from vllm.utils import Device


@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("num_blocks", [16])
def test_block_allocator(
block_size: int,
num_blocks: int,
):
llm = LLM(model=model)
# -1 since the last token can change when concatenating prompts.
prefix_pos = len(llm.llm_engine.tokenizer.encode(prefix)) - 1
prompts = [prefix + prompt for prompt in example_prompts]
sampling_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
outputs_without_prefix = llm.generate(prompts, sampling_params)
outputs_with_prefix = llm.generate(prompts,
sampling_params,
prefix_pos=[prefix_pos] * len(prompts))
for output_without_prefix, output_with_prefix in zip(
outputs_without_prefix, outputs_with_prefix):
assert (output_without_prefix.outputs[0].token_ids ==
output_with_prefix.outputs[0].token_ids)
assert len(llm.llm_engine.scheduler.prefix_pool.prefixes) == 1
block_hash = 1
block_allocator = BlockAllocator(Device.CPU,
block_size,
num_blocks,
enable_caching=True)

# Allocate two PysicalTokenBlocks with the same hash and check that they are the same PhysicalTokenBlock
first_block = block_allocator.allocate(block_hash, 0)
second_block = block_allocator.allocate(block_hash, 0)
assert (first_block == second_block)
assert (second_block.ref_count == 2)

# Free the first_block and confirm that the ref_count is correctly decremented on the second block
block_allocator.free(first_block)
assert (second_block.ref_count == 1)

# Free the second block
block_allocator.free(second_block)

# Reallocate the first block and confirm that, even after the block had its ref_count go to 0, we still get the same block back
first_block = block_allocator.allocate(block_hash, 0)
assert (first_block == second_block)
assert (first_block.block_hash == block_hash)


@pytest.mark.parametrize("num_blocks", [16])
def test_eviction(num_blocks: int, ):
block_size = 16
block_allocator = BlockAllocator(Device.CPU,
block_size,
num_blocks,
enable_caching=True)
blocks = []

for i in range(num_blocks):
# use i as the block_hash
blocks.append(block_allocator.allocate(i, 0))

#Free all blocks
for block in blocks:
block_allocator.free(block)

# Allocate a new block and confirm that it's the first block freed. I.E The Least Recently Used block
new_block_hash = block_size
new_block = block_allocator.allocate(new_block_hash, 0)
assert (new_block == blocks[0])
assert (new_block.block_hash == new_block_hash)

# Reallocate the second in blocks to remove it from the free list
realloc_block_hash = 1
realloc_block = block_allocator.allocate(realloc_block_hash, 0)
assert (realloc_block == blocks[realloc_block_hash])
assert (realloc_block.block_hash == realloc_block_hash)

# Allocate a new block and confirm that it's not the realloc_block, since the realloc_block shouldn't be in the free list
new_block_hash = block_size + 1
new_block = block_allocator.allocate(new_block_hash, 0)
assert (realloc_block != new_block)
assert (new_block.block_hash == new_block_hash)
assert (new_block.block_number == 2)
76 changes: 76 additions & 0 deletions tests/test_cache_block_hashing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
"""Test hashing of cache blocks.
Run `pytest tests/test_cache_block_hashing.py`.
"""
import pytest

from vllm.transformers_utils.tokenizer import TokenizerGroup
from vllm.sequence import Sequence

# Make two prefixes with different first blocks.
prefix_start = [("You are an expert"), ("You are a")]
prefix_common = (
" school principal, skilled in effectively managing "
"faculty and staff. Draft 10-15 questions for a potential first grade "
"Head Teacher for my K-12, all-girls', independent school that emphasizes "
"community, joyful discovery, and life-long learning. The candidate is "
"coming in for a first-round panel interview for a 8th grade Math "
"teaching role. They have 5 years of previous teaching experience "
"as an assistant teacher at a co-ed, public school with experience "
"in middle school math teaching. Based on this, fulfill "
"the following: ")
prefixes = [start + prefix_common for start in prefix_start]

# Sample prompts.
sample_prompts = [
"Hello, my name is", "The president of the United States is",
"The capital of France is", "The future of AI is"
]


# Helper function.
def flatten_2d(li):
return [lss for ls in li for lss in ls]


@pytest.mark.parametrize("model", ["facebook/opt-125m"])
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("max_num_seqs", [256])
def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int):

tokenizer = TokenizerGroup(
tokenizer_id="facebook/opt-125m",
enable_lora=False,
max_num_seqs=max_num_seqs,
max_input_length=None,
)

hashes = []

for prefix in prefixes:
hashes.append([])
prompts = [prefix + prompt for prompt in sample_prompts]
seq_id = 0
for prompt in prompts:
hashes[-1].append([])
prompt_token_ids = tokenizer.encode(prompt)
seq = Sequence(seq_id, prompt, prompt_token_ids, block_size)

num_blocks = len(prompt_token_ids) // block_size
for idx in range(num_blocks):
hashes[-1][-1].append(seq.hash_of_block(idx))

seq_id += 1

# Check that hashes made with two prefixes with different first blocks are
# different everywhere.
for hash0, hash1 in zip(flatten_2d(hashes[0]), flatten_2d(hashes[1])):
assert (hash0 != hash1)

# Check that hashes of different prompts made with the same prefix are the
# same until the hashes that contain the prompt.
for hash_pref in hashes:
same_hashes = [tuple(h[:-1]) for h in hash_pref]
different_hashes = [h[-1] for h in hash_pref]
assert (len(set(same_hashes)) == 1)
assert (len(set(different_hashes)) == len(different_hashes))
14 changes: 13 additions & 1 deletion vllm/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

_BLANK_TOKEN_ID = -1

DEFAULT_LAST_ACCESSED_TIME = -1


class LogicalTokenBlock:
"""A block that stores a contiguous chunk of tokens from left to right.
Expand Down Expand Up @@ -55,17 +57,27 @@ def __init__(
device: Device,
block_number: int,
block_size: int,
block_hash: int,
num_hashed_tokens: int,
) -> None:
self.device = device
self.block_number = block_number
self.block_size = block_size
self.block_hash = block_hash
self.num_hashed_tokens = num_hashed_tokens

self.ref_count = 0
self.last_accessed = DEFAULT_LAST_ACCESSED_TIME

self.computed = False

def __repr__(self) -> str:
return (f'PhysicalTokenBlock(device={self.device}, '
f'block_number={self.block_number}, '
f'ref_count={self.ref_count})')
f'num_hashed_tokens={self.num_hashed_tokens}, '
f'ref_count={self.ref_count}, '
f'last_accessed={self.last_accessed}, '
f'computed={self.computed})')


# Mapping: logical block number -> physical block.
Expand Down
2 changes: 2 additions & 0 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,12 +303,14 @@ def __init__(
swap_space: int,
cache_dtype: str,
sliding_window: Optional[int] = None,
enable_prefix_caching: bool = False,
) -> None:
self.block_size = block_size
self.gpu_memory_utilization = gpu_memory_utilization
self.swap_space_bytes = swap_space * _GB
self.cache_dtype = cache_dtype
self.sliding_window = sliding_window
self.enable_prefix_caching = enable_prefix_caching
self._verify_args()
self._verify_cache_dtype()

Expand Down
Loading

0 comments on commit ce4f5a2

Please sign in to comment.