diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py
index f565cbca9..ff657d29d 100644
--- a/QEfficient/transformers/models/modeling_auto.py
+++ b/QEfficient/transformers/models/modeling_auto.py
@@ -209,6 +209,7 @@ def export(self, export_dir: Optional[str] = None) -> str:
             2: "ctx_len",
         }
         output_names = ["logits"]
+
         for i in range(self.num_layers):
             for kv in ["key", "value"]:
                 example_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32))
@@ -240,6 +241,7 @@ def compile(
         ctx_len: int = 128,
         batch_size: int = 1,
         full_batch_size: Optional[int] = None,
+        kv_cache_batch_size: Optional[int] = None,
         num_devices: int = 1,
         num_cores: int = 16,  # FIXME: Make this mandatory arg
         mxfp6_matmul: bool = False,
@@ -291,15 +293,28 @@ def compile(
         if self.continuous_batching and full_batch_size is None:
             raise TypeError("missing required argument: 'full_batch_size'")
 
+        if kv_cache_batch_size and not full_batch_size:
+            raise ValueError(
+                "Prefix caching is enabled only for continuous batching as of now. Please pass `full_batch_size` argument and make sure you pass `continuous_batching=True` in the `from_pretrained` call"
+            )
+
+        kv_cache_batch_size = (
+            kv_cache_batch_size if kv_cache_batch_size else (full_batch_size if full_batch_size else batch_size)
+        )
         # Define prefill specialization
         prefill_specialization = {
             # Prefill is always run with single BS for continuous batching.
             "batch_size": 1 if self.continuous_batching else batch_size,
             "seq_len": prefill_seq_len,
             "ctx_len": ctx_len,
+            # TODO: should be renamed to kv_cache_batch_size in specialization too
         }
-        prefill_specialization.update({"full_batch_size": full_batch_size}) if self.continuous_batching else None
-        prefill_specialization.update({"num_logits_to_keep": 1}) if self.is_tlm else None
+        prefill_specialization.update({"num_logits_to_keep": 1}) if self.is_tlm else ...
+        if self.continuous_batching:
+            prefill_specialization.update({"full_batch_size": kv_cache_batch_size})
+        else:
+            prefill_specialization.update({"batch_size": kv_cache_batch_size})
+        prefill_specialization.update({"full_batch_exec_size": full_batch_size}) if full_batch_size else ...
         specializations = [
             prefill_specialization,
         ]
@@ -311,8 +326,11 @@ def compile(
             "seq_len": num_speculative_tokens + 1 if self.is_tlm else 1,
             "ctx_len": ctx_len,
         }
-        decode_specialization.update({"full_batch_size": full_batch_size}) if self.continuous_batching else None
-        decode_specialization.update({"num_logits_to_keep": num_speculative_tokens + 1}) if self.is_tlm else None
+        if self.continuous_batching:
+            decode_specialization.update({"full_batch_size": kv_cache_batch_size})
+        else:
+            decode_specialization.update({"batch_size": kv_cache_batch_size})
+        decode_specialization.update({"num_logits_to_keep": num_speculative_tokens + 1}) if self.is_tlm else ...
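+        # Decode uses the same kv_cache_batch_size as prefill for the KV-cache batch dimension.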
         specializations.append(decode_specialization)
 
         if enable_qnn:
@@ -363,7 +381,7 @@ def generate(
         self,
         tokenizer: Union[PreTrainedTokenizerFast, PreTrainedTokenizer],
         prompts: List[str],
-        device_id: List[int] = [0],
+        device_id: List[int] = None,
         runtime_ai100: bool = True,
         **kwargs,
     ):
@@ -569,7 +587,7 @@ def compile(
     def generate(
         self,
         inputs: torch.Tensor,
-        device_ids: List[int] = [0],
+        device_ids: List[int] = None,
         runtime_ai100: bool = True,
     ) -> Union[torch.Tensor, np.ndarray]:
         """
diff --git a/scripts/Jenkinsfile b/scripts/Jenkinsfile
index d1bb02a29..0d802b83f 100644
--- a/scripts/Jenkinsfile
+++ b/scripts/Jenkinsfile
@@ -48,7 +48,7 @@ pipeline {
         }
         stage('Run Non-CLI QAIC Tests') {
             steps {
-                timeout(time: 70, unit: 'MINUTES') {
+                timeout(time: 200, unit: 'MINUTES') {
                     sh '''
                     sudo docker exec ${BUILD_TAG} bash -c "
                     cd /efficient-transformers &&
@@ -56,7 +56,7 @@ pipeline {
                     mkdir -p $PWD/Non_qaic &&
                     export TOKENIZERS_PARALLELISM=false &&
                     export QEFF_HOME=$PWD/Non_qaic &&
-                    pytest tests -m '(not cli) and (on_qaic) and (not qnn)' -n 3 --junitxml=tests/tests_log2.xml &&
+                    pytest tests -m '(not cli) and (on_qaic) and (not qnn)' -n 4 --junitxml=tests/tests_log2.xml &&
                     deactivate"
                     '''
                 }
diff --git a/tests/finetune/test_finetune.py b/tests/finetune/test_finetune.py
index 4d7d061f1..45330cad6 100644
--- a/tests/finetune/test_finetune.py
+++ b/tests/finetune/test_finetune.py
@@ -26,6 +26,7 @@ def clean_up(path):
 
 
 # TODO:enable this once docker is available
+@pytest.mark.on_qaic
 @pytest.mark.skip(reason="eager docker not available in sdk")
 @pytest.mark.parametrize(
     "model_name,max_eval_step,max_train_step,intermediate_step_save,context_length,run_validation,use_peft,device",
diff --git a/tests/peft/lora/test_lora_model.py b/tests/peft/lora/test_lora_model.py
index a91555b3a..a1bea6049 100644
--- a/tests/peft/lora/test_lora_model.py
+++ b/tests/peft/lora/test_lora_model.py
@@ -229,6 +229,5 @@ def test_auto_lora_model_for_causal_lm_export_compile_generate(base_model_name,
     qeff_model.generate(
         tokenizer=load_hf_tokenizer(pretrained_model_name_or_path=base_model_name),
         prompts=prompts,
-        device_id=[0],
         prompt_to_adapter_mapping=["adapter_0", "adapter_1", "adapter_0", "base"],
     )
diff --git a/tests/text_generation/test_text_generation.py b/tests/text_generation/test_text_generation.py
index 15f4b7dcb..b8915859e 100644
--- a/tests/text_generation/test_text_generation.py
+++ b/tests/text_generation/test_text_generation.py
@@ -44,6 +44,7 @@ def load_causal_lm_model(model_config):
 
 
 # Use @pytest.mark.parametrize to apply the configurations
+@pytest.mark.on_qaic
 @pytest.mark.parametrize("model_name, n_layer, full_batch_size, max_gen_len", configs)
 def test_generate_text_stream(
     model_name: str,
diff --git a/tests/transformers/models/test_prefix_caching.py b/tests/transformers/models/test_prefix_caching.py
new file mode 100644
index 000000000..fa79f33cd
--- /dev/null
+++ b/tests/transformers/models/test_prefix_caching.py
@@ -0,0 +1,183 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+import numpy as np
+import pytest
+from transformers import AutoTokenizer
+
+from QEfficient.generation.text_generation_inference import TextGeneration
+from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM
+
+test_models = ["gpt2"]
+
+
+# The test first generates output for prefix+suffix1 on a given batch id, then confirms that we can still execute prefix+suffix2 on the same batch id and get correct output.
+@pytest.mark.on_qaic
+@pytest.mark.parametrize("model_name", test_models)
+def test_simple_prefix_caching(model_name):
+    qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_name, continuous_batching=True)
+    qeff_model.compile(
+        prefill_seq_len=128,
+        ctx_len=256,
+        full_batch_size=2,
+        kv_cache_batch_size=4,
+        num_cores=14,
+    )
+
+    prefixes = ["Once upon a time ", "Once upon a time "]
+    suffixes1 = ["in a land far away", "there was a small village"]
+    suffixes2 = ["a little girl", "in a bustling city"]
+
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+    generator = TextGeneration(tokenizer=tokenizer, qpc_path=qeff_model.qpc_path, full_batch_size=2, ctx_len=256)
+
+    prompts = [pref + suff for pref, suff in zip(prefixes, suffixes1)]
+
+    # generation for batch_indices = 0, 1
+    prompts_exec_info = generator.generate(prompts)
+    ##############################
+    # generation for batch_indices
+    ##############################
+    # Run prefill for indices 2, 3 with same prompts
+    out2, pos2, gen_len2 = generator._qaic_model.run_prefill(
+        prompts[0], generation_len=None, decode_batch_id=np.array(2, dtype=np.int64).reshape(1, 1)
+    )
+    out3, pos3, gen_len3 = generator._qaic_model.run_prefill(
+        prompts[1], generation_len=None, decode_batch_id=np.array(3, dtype=np.int64).reshape(1, 1)
+    )
+
+    # Run decode for batch indices 2, 3
+    decode_inputs = {
+        "input_ids": np.array([[out2["logits"].argmax(2)[0][0]], [out3["logits"].argmax(2)[0][0]]]),
+        "position_ids": np.array([[pos2[0][0]], [pos3[0][0]]]),
+        "batch_index": np.array([[2], [3]], dtype=np.int64),
+    }
+
+    # Set logits placeholder for decode
+    logits_out_placeholder = np.zeros(
+        (
+            generator._qaic_model.full_batch_size,
+            generator._qaic_model._decode_seq_len,
+            generator._qaic_model._vocab_size,
+        ),
+        dtype=np.float32,
+    )
+    generator._qaic_model._session.set_buffers({"logits": logits_out_placeholder})
+
+    generation_outputs = []
+    for i in range(gen_len2):
+        generation_outputs.append(decode_inputs["input_ids"])
+        outputs = generator._qaic_model._session.run(decode_inputs)
+        logits = outputs["logits"]
+        if len(logits.shape) == 2:
+            logits = np.expand_dims(logits, 1)
+        next_token_id = logits.argmax(2)
+
+        decode_inputs["input_ids"] = next_token_id
+        decode_inputs["position_ids"] += 1
+
+    assert np.all(generator._qaic_model.generated_ids[0, :gen_len2] == [int(val[0]) for val in generation_outputs])
+    assert np.all(generator._qaic_model.generated_ids[1, :gen_len2] == [int(val[1]) for val in generation_outputs])
+
+    ##############################
+    # Now rerun with cached prefix on 0th index with prompt3 and use -1 for 1st index
+    ##############################
+
+    nprompts = [pref + suff for pref, suff in zip(prefixes, suffixes2)]
+
+    ## Prefill run on index 0
+    prompt = nprompts[0]
+    inputs = tokenizer(prompt, return_tensors="np", padding=True)
+    position_ids = inputs["attention_mask"].sum(1, keepdims=True)
+    padded_len = inputs["input_ids"].shape[1]
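+    # Ceil-divide the prompt length by the prefill sequence length to get the number of prefill chunks.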
+    num_chunks = -(padded_len // -generator._qaic_model._prefill_seq_len)
+    padded_len = num_chunks * generator._qaic_model._prefill_seq_len  # Convert to a multiple of prompt_len
+
+    # Initialize variables specific to request
+    # Calculate the max generation length.
+    max_gen_len = generator._qaic_model._ctx_len - position_ids.max()
+
+    # Set the prefill logits buffer
+    logits_out_placeholder = np.zeros((1, 1, generator._qaic_model._vocab_size), dtype=np.float32)
+    generator._qaic_model._session.set_buffers({"logits": logits_out_placeholder})
+    inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len)
+    inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1)
+    inputs.pop("token_type_ids", None)
+    inputs["batch_index"] = np.array([[0]], dtype=np.int64)
+    norm_outputs = generator._qaic_model._session.run(inputs)
+    inputs["input_ids"][:, :3] = inputs["input_ids"][:, 4:7]
+    inputs["input_ids"][:, 3:] = 50256
+    inputs["position_ids"][:, :3] = inputs["position_ids"][:, 4:7]
+    inputs["position_ids"][:, 3:] = -1
+    mod_outputs = generator._qaic_model._session.run(inputs)
+    assert (mod_outputs["logits"] == norm_outputs["logits"]).all()
+    decode_inputs = {
+        "input_ids": np.array([[mod_outputs["logits"].argmax(2)[0][0]], [0]]),
+        "position_ids": np.array([[position_ids[0][0]], [-1]]),
+        "batch_index": np.array([[0], [1]], dtype=np.int64),
+    }
+
+    # Set logits placeholder for decode
+    logits_out_placeholder = np.zeros(
+        (
+            generator._qaic_model.full_batch_size,
+            generator._qaic_model._decode_seq_len,
+            generator._qaic_model._vocab_size,
+        ),
+        dtype=np.float32,
+    )
+    generator._qaic_model._session.set_buffers({"logits": logits_out_placeholder})
+
+    generation_outputs = []
+    for i in range(max_gen_len):
+        generation_outputs.append(decode_inputs["input_ids"])
+        outputs = generator._qaic_model._session.run(decode_inputs)
+        logits = outputs["logits"]
+        if len(logits.shape) == 2:
+            logits = np.expand_dims(logits, 1)
+        next_token_id = logits.argmax(2)
+
+        decode_inputs["input_ids"] = next_token_id
+        decode_inputs["position_ids"][0][0] += 1
+
+    # TODO: add a check if this matches normal execution for same prompt
+    ##############
+    # Now run decode on 1st index again with mod_inputs and check if output is correct
+    ##############
+    decode_inputs = {
+        "input_ids": np.array([[0], [prompts_exec_info.generated_ids[1][0]]]),
+        "position_ids": np.array([[-1], [9]]),
+        "batch_index": np.array([[0], [1]], dtype=np.int64),
+    }
+
+    # Set logits placeholder for decode
+    logits_out_placeholder = np.zeros(
+        (
+            generator._qaic_model.full_batch_size,
+            generator._qaic_model._decode_seq_len,
+            generator._qaic_model._vocab_size,
+        ),
+        dtype=np.float32,
+    )
+    generator._qaic_model._session.set_buffers({"logits": logits_out_placeholder})
+
+    generation_outputs_prefill_cached = []
+    for i in range(max_gen_len):
+        generation_outputs_prefill_cached.append(decode_inputs["input_ids"])
+        outputs = generator._qaic_model._session.run(decode_inputs)
+        logits = outputs["logits"]
+        if len(logits.shape) == 2:
+            logits = np.expand_dims(logits, 1)
+        next_token_id = logits.argmax(2)
+
+        decode_inputs["input_ids"] = next_token_id
+        decode_inputs["position_ids"][1][0] += 1
+
+    assert np.all(
+        prompts_exec_info.generated_ids[1][:247] == [int(val[1]) for val in generation_outputs_prefill_cached][:247]
+    )
diff --git a/tests/transformers/spd/test_spd_inference.py b/tests/transformers/spd/test_spd_inference.py
index 2e5f55cc7..18334e815
--- a/tests/transformers/spd/test_spd_inference.py
+++ b/tests/transformers/spd/test_spd_inference.py
@@ -92,6 +92,7 @@ def split_dlm_bonus_token_inputs(dlm_decode_inputs):
     return bonus_token_inputs, dlm_decode_inputs
 
 
+@pytest.mark.on_qaic
 @pytest.mark.parametrize(
     "prompt, num_speculative_tokens, prefill_seq_len, ctx_len, prefill_bsz, draft_model_name, target_model_name, full_batch_size",
     configs,
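For reference, a minimal sketch of driving the new `kv_cache_batch_size` path end to end, following the pattern in test_prefix_caching.py above; the model name, core count, and sequence lengths are taken from that test and are illustrative, not requirements.

from transformers import AutoTokenizer

from QEfficient.generation.text_generation_inference import TextGeneration
from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM

# Continuous batching with a KV cache that has more batch slots (4) than the
# execution batch (2), which is the setup the prefix-caching test exercises.
model = QEFFAutoModelForCausalLM.from_pretrained("gpt2", continuous_batching=True)
model.compile(prefill_seq_len=128, ctx_len=256, full_batch_size=2, kv_cache_batch_size=4, num_cores=14)

tokenizer = AutoTokenizer.from_pretrained("gpt2")
generator = TextGeneration(tokenizer=tokenizer, qpc_path=model.qpc_path, full_batch_size=2, ctx_len=256)

# Generate for two prompts sharing a prefix; outputs occupy KV-cache batch slots 0 and 1.
exec_info = generator.generate(["Once upon a time in a land far away", "Once upon a time there was a small village"])
print(exec_info.generated_ids)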