Commit d10f8e1 (parent: 14cc317)

[Experimental] Prefix Caching Support (vllm-project#1669)

Co-authored-by: DouHappy <[email protected]>
Co-authored-by: Zhuohan Li <[email protected]>

Showing 20 changed files with 1,356 additions and 71 deletions.
@@ -0,0 +1,51 @@ (new file: offline inference example with a shared prompt prefix)
from vllm import LLM, SamplingParams

prefix = (
    "You are an expert school principal, skilled in effectively managing "
    "faculty and staff. Draft 10-15 questions for a potential first grade "
    "Head Teacher for my K-12, all-girls', independent school that emphasizes "
    "community, joyful discovery, and life-long learning. The candidate is "
    "coming in for a first-round panel interview for a 8th grade Math "
    "teaching role. They have 5 years of previous teaching experience "
    "as an assistant teacher at a co-ed, public school with experience "
    "in middle school math teaching. Based on these information, fulfill "
    "the following paragraph: ")

# Sample prompts.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.0)

# Create an LLM.
llm = LLM(model="facebook/opt-125m")

generating_prompts = [prefix + prompt for prompt in prompts]

# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(generating_prompts, sampling_params)
# Print the outputs.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

print("-" * 80)

# -1 since the last token can change when concatenating prompts.
prefix_pos = len(llm.llm_engine.tokenizer.encode(prefix)) - 1

# Generate with prefix
outputs = llm.generate(generating_prompts, sampling_params,
                       prefix_pos=[prefix_pos] * len(generating_prompts))

# Print the outputs. You should see the same outputs as before
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
@@ -0,0 +1,168 @@ (new file: test for the Triton prefix-prefill context attention kernel)
import random
import pytest
import time

import torch
from vllm.model_executor.layers.triton_kernel.prefix_prefill import (
    context_attention_fwd)
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask

NUM_HEADS = [12]
HEAD_SIZES = [128]
DTYPES = [torch.float16]


@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@torch.inference_mode()
def test_contexted_kv_attention(
    num_heads: int,
    head_size: int,
    dtype: torch.dtype,
) -> None:
    # Compare the Triton prefix-prefill kernel against an xformers reference:
    # each of the BS sequences has ctx_len tokens already in the paged KV
    # cache and subquery_len new tokens to prefill.
    random.seed(0)
    torch.manual_seed(0)
    MAX_SEQ_LEN = 1024
    MAX_CTX_LEN = 1024
    BS = 10
    cache_size = 640
    block_size = 32
    max_block_per_request = 64
    subquery_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)]
    ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)]
    seq_lens = [a + b for a, b in zip(subquery_lens, ctx_lens)]

    num_tokens = sum(subquery_lens)
    query = torch.empty(num_tokens,
                        num_heads,
                        head_size,
                        dtype=dtype,
                        device='cuda')
    query.uniform_(-1e-3, 1e-3)
    output = torch.empty(num_tokens,
                         num_heads,
                         head_size,
                         dtype=dtype,
                         device='cuda')

    kv = torch.empty(sum(seq_lens),
                     2,
                     num_heads,
                     head_size,
                     dtype=dtype,
                     device='cuda')
    kv.uniform_(-1e-3, 1e-3)
    key, value = kv.unbind(dim=1)

    k_cache = torch.zeros(cache_size,
                          block_size,
                          num_heads,
                          head_size,
                          dtype=dtype,
                          device='cuda')
    v_cache = torch.zeros(cache_size,
                          block_size,
                          num_heads,
                          head_size,
                          dtype=dtype,
                          device='cuda')
    k = torch.zeros(sum(subquery_lens),
                    num_heads,
                    head_size,
                    dtype=dtype,
                    device='cuda')
    v = torch.zeros(sum(subquery_lens),
                    num_heads,
                    head_size,
                    dtype=dtype,
                    device='cuda')
    values = torch.arange(0, cache_size, dtype=torch.long, device='cuda')
    values = values[torch.randperm(cache_size)]
    block_table = values[:BS * max_block_per_request].view(
        BS, max_block_per_request)
    b_seq_len = torch.tensor(seq_lens, dtype=torch.long, device='cuda')
    b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long, device='cuda')
    b_start_loc = torch.cumsum(torch.tensor([0] + subquery_lens[:-1],
                                            dtype=torch.long,
                                            device='cuda'),
                               dim=0)
    max_input_len = MAX_SEQ_LEN
    # copy kv to cache
    b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1],
                                                dtype=torch.long,
                                                device='cuda'),
                                   dim=0)
    for i in range(BS):
        # Keys/values of the new (query) tokens go into k/v, ordered by
        # b_start_loc.
        for j in range(subquery_lens[i]):
            k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] +
                                            j])
            v[b_start_loc[i] + j].copy_(value[b_seq_start_loc[i] +
                                              b_ctx_len[i] + j])
        # Scatter this sequence's context tokens into the paged cache blocks
        # given by block_table.
        cur_ctx = 0
        block_id = 0
        while cur_ctx < b_ctx_len[i]:
            start_loc = b_seq_start_loc[i] + cur_ctx
            if cur_ctx + block_size > b_ctx_len[i]:
                end_loc = b_seq_start_loc[i] + b_ctx_len[i]
            else:
                end_loc = start_loc + block_size
            start_slot = block_table[i, block_id] * block_size
            end_slot = start_slot + end_loc - start_loc
            k_cache.view(-1, num_heads, head_size)[start_slot:end_slot].copy_(
                key[start_loc:end_loc])
            v_cache.view(-1, num_heads, head_size)[start_slot:end_slot].copy_(
                value[start_loc:end_loc])
            cur_ctx += block_size
            block_id += 1
    # transpose K_cache[num_blocks, block_size, num_kv_heads, head_size]
    # to K_cache[num_blocks, num_kv_heads, head_size/8, block_size, 8]
    k_cache = k_cache.view(-1, block_size, num_heads, head_size // 8,
                           8).permute(0, 2, 3, 1, 4).contiguous()
    # transpose V_cache[num_blocks, block_size, num_kv_heads, head_size]
    # to V_cache[num_blocks, num_kv_heads, head_size, block_size]
    v_cache = v_cache.view(-1, block_size, num_heads,
                           head_size).permute(0, 2, 3, 1).contiguous()

    # Warm-up run, then a timed run of the Triton kernel.
    context_attention_fwd(query, k, v, output, k_cache, v_cache, block_table,
                          b_start_loc, b_seq_len, b_ctx_len, max_input_len)
    torch.cuda.synchronize()
    start_time = time.time()
    context_attention_fwd(query, k, v, output, k_cache, v_cache, block_table,
                          b_start_loc, b_seq_len, b_ctx_len, max_input_len)
    torch.cuda.synchronize()
    end_time = time.time()
    print(f"triton Time: {(end_time - start_time)*1000:.2f} ms")

    scale = float(1.0 / (head_size**0.5))

    attn_op = xops.fmha.cutlass.FwOp()

    # Reference: xformers memory-efficient attention over the full sequences
    # with a block-diagonal causal mask, warm-up run then timed run.
    attn_bias = BlockDiagonalCausalFromBottomRightMask.from_seqlens(
        subquery_lens, seq_lens)
    output_ref = xops.memory_efficient_attention_forward(
        query.unsqueeze(0),
        key.unsqueeze(0),
        value.unsqueeze(0),
        attn_bias=attn_bias,
        p=0.0,
        scale=scale,
        op=attn_op,
    )
    torch.cuda.synchronize()
    start_time = time.time()
    output_ref = xops.memory_efficient_attention_forward(
        query.unsqueeze(0),
        key.unsqueeze(0),
        value.unsqueeze(0),
        attn_bias=attn_bias,
        p=0.0,
        scale=scale,
        op=attn_op,
    )
    torch.cuda.synchronize()
    end_time = time.time()
    print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms")
    output_ref = output_ref.squeeze(0)
    assert torch.allclose(output_ref, output, atol=1e-6, rtol=0)
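The reshapes near the end of the test mirror the paged KV-cache layout the kernel reads: keys as [num_blocks, num_kv_heads, head_size/x, block_size, x] with x = 8 here, values as [num_blocks, num_kv_heads, head_size, block_size]. The short sketch below is illustrative only and not part of the commit (the x = 8 grouping is taken from the test and may in general depend on the dtype); it checks that the key permutation is a pure re-arrangement that can be inverted exactly.

import torch

# Toy shapes; x == 8 matches the element grouping used in the test above.
num_blocks, block_size, num_heads, head_size, x = 4, 32, 12, 128, 8

k_flat = torch.arange(num_blocks * block_size * num_heads * head_size,
                      dtype=torch.float32)
k_cache = k_flat.view(num_blocks, block_size, num_heads, head_size)

# Same reshape as the test: [blocks, heads, head_size // x, block_size, x].
k_vllm = k_cache.view(num_blocks, block_size, num_heads, head_size // x,
                      x).permute(0, 2, 3, 1, 4).contiguous()

# Element (block b, position p, head h, channel c) lands at
# k_vllm[b, h, c // x, p, c % x]; verify the mapping for one entry.
b, p, h, c = 2, 17, 5, 77
assert k_vllm[b, h, c // x, p, c % x] == k_cache[b, p, h, c]

# Inverting the permutation recovers the original layout exactly.
k_back = k_vllm.permute(0, 3, 1, 2, 4).reshape(num_blocks, block_size,
                                               num_heads, head_size)
assert torch.equal(k_back, k_cache)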
@@ -0,0 +1,41 @@ (new file: tests/prefix_caching/test_prefix_caching.py)
"""Compare outputs with and without prefix caching.
Run `pytest tests/prefix_caching/test_prefix_caching.py`.
"""
import pytest

from vllm import LLM, SamplingParams

prefix = (
    "You are an expert school principal, skilled in effectively managing "
    "faculty and staff. Draft 10-15 questions for a potential first grade "
    "Head Teacher for my K-12, all-girls', independent school that emphasizes "
    "community, joyful discovery, and life-long learning. The candidate is "
    "coming in for a first-round panel interview for a 8th grade Math "
    "teaching role. They have 5 years of previous teaching experience "
    "as an assistant teacher at a co-ed, public school with experience "
    "in middle school math teaching. Based on these information, fulfill "
    "the following paragraph: ")


@pytest.mark.parametrize("model", ["facebook/opt-125m"])
@pytest.mark.parametrize("max_tokens", [16])
def test_prefix_caching(
    example_prompts,
    model: str,
    max_tokens: int,
):
    llm = LLM(model=model)
    # -1 since the last token can change when concatenating prompts.
    prefix_pos = len(llm.llm_engine.tokenizer.encode(prefix)) - 1
    prompts = [prefix + prompt for prompt in example_prompts]
    sampling_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
    outputs_without_prefix = llm.generate(prompts, sampling_params)
    outputs_with_prefix = llm.generate(prompts,
                                       sampling_params,
                                       prefix_pos=[prefix_pos] * len(prompts))
    for output_without_prefix, output_with_prefix in zip(
            outputs_without_prefix, outputs_with_prefix):
        assert (output_without_prefix.outputs[0].token_ids ==
                output_with_prefix.outputs[0].token_ids)
    assert len(llm.llm_engine.scheduler.prefix_pool.prefixes) == 1
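The final assertion above checks that all prompts shared a single entry in the scheduler's prefix pool. As a hedged companion sketch (not part of this commit, relying on the same internal `llm_engine.scheduler.prefix_pool` attribute, and assuming both prefixes are long enough to be cached), one would expect two distinct prefixes to leave two pool entries:

second_prefix = (
    "You are a seasoned hiring manager at a large software company, "
    "responsible for interviewing senior backend engineers. Draft a short "
    "set of first-round interview questions for the following candidate: ")


@pytest.mark.parametrize("model", ["facebook/opt-125m"])
@pytest.mark.parametrize("max_tokens", [16])
def test_two_prefixes(  # hypothetical companion test, not in this commit
    example_prompts,
    model: str,
    max_tokens: int,
):
    llm = LLM(model=model)
    sampling_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
    for pfx in (prefix, second_prefix):
        # -1 since the last token can change when concatenating prompts.
        pfx_pos = len(llm.llm_engine.tokenizer.encode(pfx)) - 1
        prompts = [pfx + prompt for prompt in example_prompts]
        llm.generate(prompts, sampling_params,
                     prefix_pos=[pfx_pos] * len(prompts))
    # Expected: one pool entry per distinct shared prefix.
    assert len(llm.llm_engine.scheduler.prefix_pool.prefixes) == 2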