Scratch prefix caching #218

Merged (17 commits) on Jan 11, 2025
30 changes: 24 additions & 6 deletions QEfficient/transformers/models/modeling_auto.py
@@ -209,6 +209,7 @@ def export(self, export_dir: Optional[str] = None) -> str:
2: "ctx_len",
}
output_names = ["logits"]

for i in range(self.num_layers):
for kv in ["key", "value"]:
example_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32))
@@ -240,6 +241,7 @@ def compile(
ctx_len: int = 128,
batch_size: int = 1,
full_batch_size: Optional[int] = None,
kv_cache_batch_size: Optional[int] = None,
num_devices: int = 1,
num_cores: int = 16, # FIXME: Make this mandatory arg
mxfp6_matmul: bool = False,
@@ -291,15 +293,28 @@ def compile(
if self.continuous_batching and full_batch_size is None:
raise TypeError("missing required argument: 'full_batch_size'")

if kv_cache_batch_size and not full_batch_size:
raise ValueError(
"Prefix caching is enabled only for continuous batching as of now. Please pass `full_batch_size` argument and make sure you pass `continuous_batching=True` in the `from_pretrained` call"
)

kv_cache_batch_size = (
kv_cache_batch_size if kv_cache_batch_size else (full_batch_size if full_batch_size else batch_size)
)
# Define prefill specialization
prefill_specialization = {
# Prefill is always run with single BS for continuous batching.
"batch_size": 1 if self.continuous_batching else batch_size,
"seq_len": prefill_seq_len,
"ctx_len": ctx_len,
# TODO: should be renamed to kv_cache_batch_size in specialization too
}
prefill_specialization.update({"full_batch_size": full_batch_size}) if self.continuous_batching else None
prefill_specialization.update({"num_logits_to_keep": 1}) if self.is_tlm else None
prefill_specialization.update({"num_logits_to_keep": 1}) if self.is_tlm else ...
if self.continuous_batching:
prefill_specialization.update({"full_batch_size": kv_cache_batch_size})
else:
prefill_specialization.update({"batch_size": kv_cache_batch_size})
prefill_specialization.update({"full_batch_exec_size": full_batch_size}) if full_batch_size else ...
specializations = [
prefill_specialization,
]
@@ -311,8 +326,11 @@
"seq_len": num_speculative_tokens + 1 if self.is_tlm else 1,
"ctx_len": ctx_len,
}
decode_specialization.update({"full_batch_size": full_batch_size}) if self.continuous_batching else None
decode_specialization.update({"num_logits_to_keep": num_speculative_tokens + 1}) if self.is_tlm else None
if self.continuous_batching:
decode_specialization.update({"full_batch_size": kv_cache_batch_size})
else:
decode_specialization.update({"batch_size": kv_cache_batch_size})
decode_specialization.update({"num_logits_to_keep": num_speculative_tokens + 1}) if self.is_tlm else ...
specializations.append(decode_specialization)

if enable_qnn:
@@ -363,7 +381,7 @@ def generate(
self,
tokenizer: Union[PreTrainedTokenizerFast, PreTrainedTokenizer],
prompts: List[str],
device_id: List[int] = [0],
device_id: List[int] = None,
runtime_ai100: bool = True,
**kwargs,
):
@@ -569,7 +587,7 @@ def compile(
def generate(
self,
inputs: torch.Tensor,
device_ids: List[int] = [0],
device_ids: List[int] = None,
runtime_ai100: bool = True,
) -> Union[torch.Tensor, np.ndarray]:
"""
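A minimal usage sketch of the new `kv_cache_batch_size` compile option, based on the diff above and the test added in this PR (the model name, sequence lengths, and `num_cores=14` are taken from that test and assume a Cloud AI 100 device is available):

from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM

# Prefix caching currently requires continuous batching, so it must be
# requested at load time; compile() raises otherwise (see the check above).
model = QEFFAutoModelForCausalLM.from_pretrained("gpt2", continuous_batching=True)

# kv_cache_batch_size may exceed full_batch_size so that spare KV-cache slots
# can hold cached prefixes; when omitted it falls back to full_batch_size
# (or batch_size when continuous batching is disabled).
model.compile(
    prefill_seq_len=128,
    ctx_len=256,
    full_batch_size=2,
    kv_cache_batch_size=4,
    num_cores=14,
)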
4 changes: 2 additions & 2 deletions scripts/Jenkinsfile
@@ -48,15 +48,15 @@ pipeline {
}
stage('Run Non-CLI QAIC Tests') {
steps {
timeout(time: 70, unit: 'MINUTES') {
timeout(time: 200, unit: 'MINUTES') {
sh '''
sudo docker exec ${BUILD_TAG} bash -c "
cd /efficient-transformers &&
. preflight_qeff/bin/activate &&
mkdir -p $PWD/Non_qaic &&
export TOKENIZERS_PARALLELISM=false &&
export QEFF_HOME=$PWD/Non_qaic &&
pytest tests -m '(not cli) and (on_qaic) and (not qnn)' -n 3 --junitxml=tests/tests_log2.xml &&
pytest tests -m '(not cli) and (on_qaic) and (not qnn)' -n 4 --junitxml=tests/tests_log2.xml &&
deactivate"
'''
}
1 change: 1 addition & 0 deletions tests/finetune/test_finetune.py
@@ -26,6 +26,7 @@ def clean_up(path):


# TODO: enable this once docker is available
@pytest.mark.on_qaic
@pytest.mark.skip(reason="eager docker not available in sdk")
@pytest.mark.parametrize(
"model_name,max_eval_step,max_train_step,intermediate_step_save,context_length,run_validation,use_peft,device",
1 change: 0 additions & 1 deletion tests/peft/lora/test_lora_model.py
@@ -229,6 +229,5 @@ def test_auto_lora_model_for_causal_lm_export_compile_generate(base_model_name,
qeff_model.generate(
tokenizer=load_hf_tokenizer(pretrained_model_name_or_path=base_model_name),
prompts=prompts,
device_id=[0],
prompt_to_adapter_mapping=["adapter_0", "adapter_1", "adapter_0", "base"],
)
1 change: 1 addition & 0 deletions tests/text_generation/test_text_generation.py
@@ -44,6 +44,7 @@ def load_causal_lm_model(model_config):


# Use @pytest.mark.parametrize to apply the configurations
@pytest.mark.on_qaic
@pytest.mark.parametrize("model_name, n_layer, full_batch_size, max_gen_len", configs)
def test_generate_text_stream(
model_name: str,
183 changes: 183 additions & 0 deletions tests/transformers/models/test_prefix_caching.py
@@ -0,0 +1,183 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

import numpy as np
import pytest
from transformers import AutoTokenizer

from QEfficient.generation.text_generation_inference import TextGeneration
from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM

test_models = ["gpt2"]


# The test first generates output for prefix+suffix1 on a given batch_id, then confirms that we can still execute prefix+suffix2 on the same batch_id and get the correct output.
@pytest.mark.on_qaic
@pytest.mark.parametrize("model_name", test_models)
def test_simple_prefix_caching(model_name):
qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_name, continuous_batching=True)
qeff_model.compile(
prefill_seq_len=128,
ctx_len=256,
full_batch_size=2,
kv_cache_batch_size=4,
num_cores=14,
)

prefixes = ["Once upon a time ", "Once upon a time "]
suffixes1 = ["in a land far away", "there was a small village"]
suffixes2 = ["a little girl", "in a bustling city"]

tokenizer = AutoTokenizer.from_pretrained(model_name)

generator = TextGeneration(tokenizer=tokenizer, qpc_path=qeff_model.qpc_path, full_batch_size=2, ctx_len=256)

prompts = [pref + suff for pref, suff in zip(prefixes, suffixes1)]

# generation for batch_indices = 0, 1
prompts_exec_info = generator.generate(prompts)
##############################
# generation for batch_indices
##############################
# Run prefill for indices 2, 3 with same prompts
out2, pos2, gen_len2 = generator._qaic_model.run_prefill(
prompts[0], generation_len=None, decode_batch_id=np.array(2, dtype=np.int64).reshape(1, 1)
)
out3, pos3, gen_len3 = generator._qaic_model.run_prefill(
prompts[1], generation_len=None, decode_batch_id=np.array(3, dtype=np.int64).reshape(1, 1)
)

# Run decode for batch indices 2, 3
decode_inputs = {
"input_ids": np.array([[out2["logits"].argmax(2)[0][0]], [out3["logits"].argmax(2)[0][0]]]),
"position_ids": np.array([[pos2[0][0]], [pos3[0][0]]]),
"batch_index": np.array([[2], [3]], dtype=np.int64),
}

# Set logits placeholder for decode
logits_out_placeholder = np.zeros(
(
generator._qaic_model.full_batch_size,
generator._qaic_model._decode_seq_len,
generator._qaic_model._vocab_size,
),
dtype=np.float32,
)
generator._qaic_model._session.set_buffers({"logits": logits_out_placeholder})

generation_outputs = []
for i in range(gen_len2):
generation_outputs.append(decode_inputs["input_ids"])
outputs = generator._qaic_model._session.run(decode_inputs)
logits = outputs["logits"]
if len(logits.shape) == 2:
logits = np.expand_dims(logits, 1)
next_token_id = logits.argmax(2)

decode_inputs["input_ids"] = next_token_id
decode_inputs["position_ids"] += 1

assert np.all(generator._qaic_model.generated_ids[0, :gen_len2] == [int(val[0]) for val in generation_outputs])
assert np.all(generator._qaic_model.generated_ids[1, :gen_len2] == [int(val[1]) for val in generation_outputs])

##############################
# Now rerun with cached prefix on 0th index with prompt3 and use -1 for 1st index
##############################

nprompts = [pref + suff for pref, suff in zip(prefixes, suffixes2)]

## Prefill run on index 0
prompt = nprompts[0]
inputs = tokenizer(prompt, return_tensors="np", padding=True)
position_ids = inputs["attention_mask"].sum(1, keepdims=True)
padded_len = inputs["input_ids"].shape[1]
num_chunks = -(padded_len // -generator._qaic_model._prefill_seq_len)
padded_len = num_chunks * generator._qaic_model._prefill_seq_len # Convert to a multiple of prompt_len

# Initialize variables specific to request
# Calculate the max generation length.
max_gen_len = generator._qaic_model._ctx_len - position_ids.max()

# Set the prefill logic buffer
logits_out_placeholder = np.zeros((1, 1, generator._qaic_model._vocab_size), dtype=np.float32)
generator._qaic_model._session.set_buffers({"logits": logits_out_placeholder})
inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len)
inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1)
inputs.pop("token_type_ids", None)
inputs["batch_index"] = np.array([[0]], dtype=np.int64)
norm_outputs = generator._qaic_model._session.run(inputs)
inputs["input_ids"][:, :3] = inputs["input_ids"][:, 4:7]
inputs["input_ids"][:, 3:] = 50256
inputs["position_ids"][:, :3] = inputs["position_ids"][:, 4:7]
inputs["position_ids"][:, 3:] = -1
mod_outputs = generator._qaic_model._session.run(inputs)
assert (mod_outputs["logits"] == norm_outputs["logits"]).all()
decode_inputs = {
"input_ids": np.array([[mod_outputs["logits"].argmax(2)[0][0]], [0]]),
"position_ids": np.array([[position_ids[0][0]], [-1]]),
"batch_index": np.array([[0], [1]], dtype=np.int64),
}

# Set logits placeholder for decode
logits_out_placeholder = np.zeros(
(
generator._qaic_model.full_batch_size,
generator._qaic_model._decode_seq_len,
generator._qaic_model._vocab_size,
),
dtype=np.float32,
)
generator._qaic_model._session.set_buffers({"logits": logits_out_placeholder})

generation_outputs = []
for i in range(max_gen_len):
generation_outputs.append(decode_inputs["input_ids"])
outputs = generator._qaic_model._session.run(decode_inputs)
logits = outputs["logits"]
if len(logits.shape) == 2:
logits = np.expand_dims(logits, 1)
next_token_id = logits.argmax(2)

decode_inputs["input_ids"] = next_token_id
decode_inputs["position_ids"][0][0] += 1

# TODO: add a check if this matches normal execution for same prompt
##############
# Now run decode on 1st index again with mod_inputs and check if output is correct
##############
decode_inputs = {
"input_ids": np.array([[0], [prompts_exec_info.generated_ids[1][0]]]),
"position_ids": np.array([[-1], [9]]),
"batch_index": np.array([[0], [1]], dtype=np.int64),
}

# Set logits placeholder for decode
logits_out_placeholder = np.zeros(
(
generator._qaic_model.full_batch_size,
generator._qaic_model._decode_seq_len,
generator._qaic_model._vocab_size,
),
dtype=np.float32,
)
generator._qaic_model._session.set_buffers({"logits": logits_out_placeholder})

generation_outputs_prefill_cached = []
for i in range(max_gen_len):
generation_outputs_prefill_cached.append(decode_inputs["input_ids"])
outputs = generator._qaic_model._session.run(decode_inputs)
logits = outputs["logits"]
if len(logits.shape) == 2:
logits = np.expand_dims(logits, 1)
next_token_id = logits.argmax(2)

decode_inputs["input_ids"] = next_token_id
decode_inputs["position_ids"][1][0] += 1

assert np.all(
prompts_exec_info.generated_ids[1][:247] == [int(val[1]) for val in generation_outputs_prefill_cached][:247]
)
1 change: 1 addition & 0 deletions tests/transformers/spd/test_spd_inference.py
@@ -92,6 +92,7 @@ def split_dlm_bonus_token_inputs(dlm_decode_inputs):
return bonus_token_inputs, dlm_decode_inputs


@pytest.mark.on_qaic
@pytest.mark.parametrize(
"prompt, num_speculative_tokens, prefill_seq_len, ctx_len, prefill_bsz, draft_model_name, target_model_name, full_batch_size",
configs,