Scratch prefix caching (#218)
Signed-off-by: Onkar Chougule <[email protected]>
ochougul authored Jan 11, 2025
1 parent b845f8e commit 314009e
Showing 7 changed files with 212 additions and 9 deletions.
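
This commit adds a `kv_cache_batch_size` argument to `compile`, allowing the KV cache to be sized with more batch slots than the number of concurrently decoded sequences (`full_batch_size`), so that already-prefilled prefixes can be kept and reused on their own batch indices. A minimal usage sketch, based on the API exercised in the new test below (model name and sizes are illustrative, not prescriptive):

from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM

# Prefix caching currently requires continuous batching.
qeff_model = QEFFAutoModelForCausalLM.from_pretrained("gpt2", continuous_batching=True)

# Reserve 4 KV-cache slots while decoding at most 2 sequences concurrently;
# the extra slots hold cached prefixes addressable via explicit batch indices.
qeff_model.compile(
    prefill_seq_len=128,
    ctx_len=256,
    full_batch_size=2,
    kv_cache_batch_size=4,
    num_cores=14,
)
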
30 changes: 24 additions & 6 deletions QEfficient/transformers/models/modeling_auto.py
@@ -209,6 +209,7 @@ def export(self, export_dir: Optional[str] = None) -> str:
2: "ctx_len",
}
output_names = ["logits"]

for i in range(self.num_layers):
for kv in ["key", "value"]:
example_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32))
@@ -240,6 +241,7 @@ def compile(
ctx_len: int = 128,
batch_size: int = 1,
full_batch_size: Optional[int] = None,
kv_cache_batch_size: Optional[int] = None,
num_devices: int = 1,
num_cores: int = 16, # FIXME: Make this mandatory arg
mxfp6_matmul: bool = False,
@@ -291,15 +293,28 @@
if self.continuous_batching and full_batch_size is None:
raise TypeError("missing required argument: 'full_batch_size'")

if kv_cache_batch_size and not full_batch_size:
raise ValueError(
"Prefix caching is enabled only for continuous batching as of now. Please pass `full_batch_size` argument and make sure you pass `continuous_batching=True` in the `from_pretrained` call"
)

kv_cache_batch_size = (
kv_cache_batch_size if kv_cache_batch_size else (full_batch_size if full_batch_size else batch_size)
)
# Define prefill specialization
prefill_specialization = {
# Prefill is always run with single BS for continuous batching.
"batch_size": 1 if self.continuous_batching else batch_size,
"seq_len": prefill_seq_len,
"ctx_len": ctx_len,
# TODO: should be renamed to kv_cache_batch_size in specialization too
}
prefill_specialization.update({"full_batch_size": full_batch_size}) if self.continuous_batching else None
prefill_specialization.update({"num_logits_to_keep": 1}) if self.is_tlm else None
prefill_specialization.update({"num_logits_to_keep": 1}) if self.is_tlm else ...
if self.continuous_batching:
prefill_specialization.update({"full_batch_size": kv_cache_batch_size})
else:
prefill_specialization.update({"batch_size": kv_cache_batch_size})
prefill_specialization.update({"full_batch_exec_size": full_batch_size}) if full_batch_size else ...
specializations = [
prefill_specialization,
]
@@ -311,8 +326,11 @@
"seq_len": num_speculative_tokens + 1 if self.is_tlm else 1,
"ctx_len": ctx_len,
}
decode_specialization.update({"full_batch_size": full_batch_size}) if self.continuous_batching else None
decode_specialization.update({"num_logits_to_keep": num_speculative_tokens + 1}) if self.is_tlm else None
if self.continuous_batching:
decode_specialization.update({"full_batch_size": kv_cache_batch_size})
else:
decode_specialization.update({"batch_size": kv_cache_batch_size})
decode_specialization.update({"num_logits_to_keep": num_speculative_tokens + 1}) if self.is_tlm else ...
specializations.append(decode_specialization)

if enable_qnn:
@@ -363,7 +381,7 @@ def generate(
self,
tokenizer: Union[PreTrainedTokenizerFast, PreTrainedTokenizer],
prompts: List[str],
device_id: List[int] = [0],
device_id: List[int] = None,
runtime_ai100: bool = True,
**kwargs,
):
@@ -569,7 +587,7 @@ def compile(
def generate(
self,
inputs: torch.Tensor,
device_ids: List[int] = [0],
device_ids: List[int] = None,
runtime_ai100: bool = True,
) -> Union[torch.Tensor, np.ndarray]:
"""
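
For orientation, a brief sketch of how the new argument is resolved in the `compile()` changes above; the concrete values are illustrative, and only the specialization keys visible in this hunk are shown:

# Fallback chain: explicit kv_cache_batch_size, else full_batch_size
# (continuous batching), else batch_size.
batch_size, full_batch_size, kv_cache_batch_size = 1, 2, 4  # illustrative
kv_cache_batch_size = (
    kv_cache_batch_size if kv_cache_batch_size else (full_batch_size if full_batch_size else batch_size)
)  # -> 4

# With continuous batching, prefill runs at batch size 1 while the KV cache is
# specialized to kv_cache_batch_size slots; full_batch_exec_size carries the
# number of concurrently decoded sequences.
prefill_specialization = {"batch_size": 1, "seq_len": 128, "ctx_len": 256}
prefill_specialization.update({"full_batch_size": kv_cache_batch_size})   # 4 KV-cache slots
prefill_specialization.update({"full_batch_exec_size": full_batch_size})  # 2 active sequences
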
4 changes: 2 additions & 2 deletions scripts/Jenkinsfile
@@ -48,15 +48,15 @@ pipeline {
}
stage('Run Non-CLI QAIC Tests') {
steps {
timeout(time: 70, unit: 'MINUTES') {
timeout(time: 200, unit: 'MINUTES') {
sh '''
sudo docker exec ${BUILD_TAG} bash -c "
cd /efficient-transformers &&
. preflight_qeff/bin/activate &&
mkdir -p $PWD/Non_qaic &&
export TOKENIZERS_PARALLELISM=false &&
export QEFF_HOME=$PWD/Non_qaic &&
pytest tests -m '(not cli) and (on_qaic) and (not qnn)' -n 3 --junitxml=tests/tests_log2.xml &&
pytest tests -m '(not cli) and (on_qaic) and (not qnn)' -n 4 --junitxml=tests/tests_log2.xml &&
deactivate"
'''
}
1 change: 1 addition & 0 deletions tests/finetune/test_finetune.py
@@ -26,6 +26,7 @@ def clean_up(path):


# TODO: enable this once docker is available
@pytest.mark.on_qaic
@pytest.mark.skip(reason="eager docker not available in sdk")
@pytest.mark.parametrize(
"model_name,max_eval_step,max_train_step,intermediate_step_save,context_length,run_validation,use_peft,device",
1 change: 0 additions & 1 deletion tests/peft/lora/test_lora_model.py
@@ -229,6 +229,5 @@ def test_auto_lora_model_for_causal_lm_export_compile_generate(base_model_name,
qeff_model.generate(
tokenizer=load_hf_tokenizer(pretrained_model_name_or_path=base_model_name),
prompts=prompts,
device_id=[0],
prompt_to_adapter_mapping=["adapter_0", "adapter_1", "adapter_0", "base"],
)
1 change: 1 addition & 0 deletions tests/text_generation/test_text_generation.py
@@ -44,6 +44,7 @@ def load_causal_lm_model(model_config):


# Use @pytest.mark.parametrize to apply the configurations
@pytest.mark.on_qaic
@pytest.mark.parametrize("model_name, n_layer, full_batch_size, max_gen_len", configs)
def test_generate_text_stream(
model_name: str,
183 changes: 183 additions & 0 deletions tests/transformers/models/test_prefix_caching.py
@@ -0,0 +1,183 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

import numpy as np
import pytest
from transformers import AutoTokenizer

from QEfficient.generation.text_generation_inference import TextGeneration
from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM

test_models = ["gpt2"]


# The test should first generate output for prefix+suffix1 on a given batch_id, and then confirm that we can still execute prefix+suffix2 on the same batch id and get correct output.
@pytest.mark.on_qaic
@pytest.mark.parametrize("model_name", test_models)
def test_simple_prefix_caching(model_name):
qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_name, continuous_batching=True)
qeff_model.compile(
prefill_seq_len=128,
ctx_len=256,
full_batch_size=2,
kv_cache_batch_size=4,
num_cores=14,
)

prefixes = ["Once upon a time ", "Once upon a time "]
suffixes1 = ["in a land far away", "there was a small village"]
suffixes2 = ["a little girl", "in a bustling city"]

tokenizer = AutoTokenizer.from_pretrained(model_name)

generator = TextGeneration(tokenizer=tokenizer, qpc_path=qeff_model.qpc_path, full_batch_size=2, ctx_len=256)

prompts = [pref + suff for pref, suff in zip(prefixes, suffixes1)]

# generation for batch_indices = 0, 1
prompts_exec_info = generator.generate(prompts)
##############################
# generation for batch_indices
##############################
# Run prefill for indices 2, 3 with same prompts
out2, pos2, gen_len2 = generator._qaic_model.run_prefill(
prompts[0], generation_len=None, decode_batch_id=np.array(2, dtype=np.int64).reshape(1, 1)
)
out3, pos3, gen_len3 = generator._qaic_model.run_prefill(
prompts[1], generation_len=None, decode_batch_id=np.array(3, dtype=np.int64).reshape(1, 1)
)

# Run decode for batch indices 2, 3
decode_inputs = {
"input_ids": np.array([[out2["logits"].argmax(2)[0][0]], [out3["logits"].argmax(2)[0][0]]]),
"position_ids": np.array([[pos2[0][0]], [pos3[0][0]]]),
"batch_index": np.array([[2], [3]], dtype=np.int64),
}

# Set logits placeholder for decode
logits_out_placeholder = np.zeros(
(
generator._qaic_model.full_batch_size,
generator._qaic_model._decode_seq_len,
generator._qaic_model._vocab_size,
),
dtype=np.float32,
)
generator._qaic_model._session.set_buffers({"logits": logits_out_placeholder})

generation_outputs = []
for i in range(gen_len2):
generation_outputs.append(decode_inputs["input_ids"])
outputs = generator._qaic_model._session.run(decode_inputs)
logits = outputs["logits"]
if len(logits.shape) == 2:
logits = np.expand_dims(logits, 1)
next_token_id = logits.argmax(2)

decode_inputs["input_ids"] = next_token_id
decode_inputs["position_ids"] += 1

assert np.all(generator._qaic_model.generated_ids[0, :gen_len2] == [int(val[0]) for val in generation_outputs])
assert np.all(generator._qaic_model.generated_ids[1, :gen_len2] == [int(val[1]) for val in generation_outputs])

##############################
# Now rerun on the 0th index, reusing the cached prefix with a new suffix, and use -1 (inactive) for the 1st index
##############################

nprompts = [pref + suff for pref, suff in zip(prefixes, suffixes2)]

## Prefill run on index 0
prompt = nprompts[0]
inputs = tokenizer(prompt, return_tensors="np", padding=True)
position_ids = inputs["attention_mask"].sum(1, keepdims=True)
padded_len = inputs["input_ids"].shape[1]
num_chunks = -(padded_len // -generator._qaic_model._prefill_seq_len)
padded_len = num_chunks * generator._qaic_model._prefill_seq_len # Convert to a multiple of prompt_len

# Initialize variables specific to request
# Calculate the max generation length.
max_gen_len = generator._qaic_model._ctx_len - position_ids.max()

# Set the prefill logits buffer
logits_out_placeholder = np.zeros((1, 1, generator._qaic_model._vocab_size), dtype=np.float32)
generator._qaic_model._session.set_buffers({"logits": logits_out_placeholder})
inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len)
inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1)
inputs.pop("token_type_ids", None)
inputs["batch_index"] = np.array([[0]], dtype=np.int64)
norm_outputs = generator._qaic_model._session.run(inputs)
inputs["input_ids"][:, :3] = inputs["input_ids"][:, 4:7]
inputs["input_ids"][:, 3:] = 50256
inputs["position_ids"][:, :3] = inputs["position_ids"][:, 4:7]
inputs["position_ids"][:, 3:] = -1
mod_outputs = generator._qaic_model._session.run(inputs)
assert (mod_outputs["logits"] == norm_outputs["logits"]).all()
decode_inputs = {
"input_ids": np.array([[mod_outputs["logits"].argmax(2)[0][0]], [0]]),
"position_ids": np.array([[position_ids[0][0]], [-1]]),
"batch_index": np.array([[0], [1]], dtype=np.int64),
}

# Set logits placeholder for decode
logits_out_placeholder = np.zeros(
(
generator._qaic_model.full_batch_size,
generator._qaic_model._decode_seq_len,
generator._qaic_model._vocab_size,
),
dtype=np.float32,
)
generator._qaic_model._session.set_buffers({"logits": logits_out_placeholder})

generation_outputs = []
for i in range(max_gen_len):
generation_outputs.append(decode_inputs["input_ids"])
outputs = generator._qaic_model._session.run(decode_inputs)
logits = outputs["logits"]
if len(logits.shape) == 2:
logits = np.expand_dims(logits, 1)
next_token_id = logits.argmax(2)

decode_inputs["input_ids"] = next_token_id
decode_inputs["position_ids"][0][0] += 1

# TODO: add a check if this matches normal execution for same prompt
##############
# Now run decode on 1st index again with mod_inputs and check if output is correct
##############
decode_inputs = {
"input_ids": np.array([[0], [prompts_exec_info.generated_ids[1][0]]]),
"position_ids": np.array([[-1], [9]]),
"batch_index": np.array([[0], [1]], dtype=np.int64),
}

# Set logits placeholder for decode
logits_out_placeholder = np.zeros(
(
generator._qaic_model.full_batch_size,
generator._qaic_model._decode_seq_len,
generator._qaic_model._vocab_size,
),
dtype=np.float32,
)
generator._qaic_model._session.set_buffers({"logits": logits_out_placeholder})

generation_outputs_prefill_cached = []
for i in range(max_gen_len):
generation_outputs_prefill_cached.append(decode_inputs["input_ids"])
outputs = generator._qaic_model._session.run(decode_inputs)
logits = outputs["logits"]
if len(logits.shape) == 2:
logits = np.expand_dims(logits, 1)
next_token_id = logits.argmax(2)

decode_inputs["input_ids"] = next_token_id
decode_inputs["position_ids"][1][0] += 1

assert np.all(
prompts_exec_info.generated_ids[1][:247] == [int(val[1]) for val in generation_outputs_prefill_cached][:247]
)
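
The decode loops in this test all follow the same slot-addressing pattern; a condensed sketch of that pattern is below (the `session`, `next_tok_slot0`, and `current_pos_slot0` names are placeholders for the corresponding objects and values used above):

# Condensed decode step: batch_index selects which KV-cache slot each row
# targets, and the test drives the inactive row with position_ids of -1.
decode_inputs = {
    "input_ids": np.array([[next_tok_slot0], [0]]),          # one token per row
    "position_ids": np.array([[current_pos_slot0], [-1]]),   # -1 -> row 1 inactive
    "batch_index": np.array([[0], [1]], dtype=np.int64),     # KV-cache slot ids
}
outputs = session.run(decode_inputs)        # session == generator._qaic_model._session
logits = outputs["logits"]
if len(logits.shape) == 2:                  # normalize to (batch, 1, vocab)
    logits = np.expand_dims(logits, 1)
decode_inputs["input_ids"] = logits.argmax(2)
decode_inputs["position_ids"][0][0] += 1    # advance only the active row
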
1 change: 1 addition & 0 deletions tests/transformers/spd/test_spd_inference.py
@@ -92,6 +92,7 @@ def split_dlm_bonus_token_inputs(dlm_decode_inputs):
return bonus_token_inputs, dlm_decode_inputs


@pytest.mark.on_qaic
@pytest.mark.parametrize(
"prompt, num_speculative_tokens, prefill_seq_len, ctx_len, prefill_bsz, draft_model_name, target_model_name, full_batch_size",
configs,
