From dbc17129659ea35d0de79088a50e1f95c762056a Mon Sep 17 00:00:00 2001
From: Apoorva Gokhale
Date: Mon, 23 Sep 2024 14:42:32 -0700
Subject: [PATCH] [eplatero] Add support for exporting and compiling models
 for SpD (https://jira-dc.qualcomm.com/jira/browse/CLOUDPERF-43)

This change has been validated and posted on behalf of Erick Platero. It adds
support for generating a Target LM that runs as a verifier model by outputting
the logits of every position in the input sequence instead of only those of
the last position. It also allows compiling the Target and Draft LMs with
specializations that support SpD.

Usage:

TLM:
tlm = QEFFAutoModelForCausalLM.from_pretrained()
tlm.transform(num_speculative_tokens=)
tlm.export_and_compile()

DLM:
dlm = QEFFAutoModelForCausalLM.from_pretrained()
dlm.transform(is_dlm=True)
dlm.export_and_compile()
---
 QEfficient/compile/compile_helper.py          | 49 ++++++++++++++-----
 .../exporter/export_hf_to_cloud_ai_100.py     |  1 +
 QEfficient/transformers/modeling_spd_utils.py | 18 +++++++
 .../models/llama/modeling_llama.py            |  4 +-
 .../transformers/models/modeling_auto.py      | 16 +++++-
 QEfficient/utils/generate_inputs.py           | 19 +++++--
 6 files changed, 88 insertions(+), 19 deletions(-)
 create mode 100644 QEfficient/transformers/modeling_spd_utils.py

diff --git a/QEfficient/compile/compile_helper.py b/QEfficient/compile/compile_helper.py
index f4882efc0..0fe3718c2 100644
--- a/QEfficient/compile/compile_helper.py
+++ b/QEfficient/compile/compile_helper.py
@@ -15,24 +15,47 @@
 
 
 def create_and_dump_specializations(
-    batch_size: int, prompt_len: int, ctx_len: int, path: str, full_batch_size: Optional[int] = None
+    batch_size: int,
+    prompt_len: int,
+    ctx_len: int,
+    path: str,
+    is_dlm: bool,
+    full_batch_size: Optional[int] = None,
+    num_speculative_tokens: Optional[int] = None,
 ):
-    # Create specialization file.
-    specializations = {
-        "specializations": [
-            {
-                "batch_size": str(batch_size),
-                "seq_len": str(prompt_len),
-                "ctx_len": str(ctx_len),
-            },
-            {"batch_size": str(batch_size), "seq_len": "1", "ctx_len": str(ctx_len)},
-        ]
-    }
+    # Create specialization cfgs
+    prefill_specialization = {"batch_size": str(batch_size), "seq_len": str(prompt_len), "ctx_len": str(ctx_len)}
+    if num_speculative_tokens is None:
+        decode_specialization = {
+            "batch_size": str(batch_size),
+            "seq_len": "1",
+            "ctx_len": str(ctx_len),
+        }
+    else:
+        decode_specialization = {
+            "batch_size": str(batch_size),
+            "seq_len": str(num_speculative_tokens + 1),
+            "ctx_len": str(ctx_len),
+        }
+    specialization_cfgs = [prefill_specialization, decode_specialization]
+    if is_dlm:
+        dlm_specialization = {
+            "batch_size": str(batch_size),
+            "seq_len": "2",
+            "ctx_len": str(ctx_len),
+        }
+        specialization_cfgs.append(dlm_specialization)
+
+    specializations = dict(specializations=specialization_cfgs)
+
     # If continuous batching is enabled by proving full_batch_size we need to add FBS to the specialization file and update the batch size of decoder part to FBS
     if full_batch_size is not None:
         specializations["specializations"][0]["full_batch_size"] = str(full_batch_size)
         specializations["specializations"][1]["full_batch_size"] = str(full_batch_size)
         specializations["specializations"][1]["batch_size"] = str(full_batch_size)
+        if len(specializations["specializations"]) == 3:
+            specializations["specializations"][2]["batch_size"] = str(full_batch_size)
+            specializations["specializations"][2]["full_batch_size"] = str(full_batch_size)
 
     # Dump
     with open(path, "w") as file:
@@ -158,6 +181,8 @@ def compile(
         ctx_len=ctx_len,
         path=specialization_json_path,
         full_batch_size=full_batch_size,
+        is_dlm=kwargs.get("is_dlm", None),
+        num_speculative_tokens=kwargs.get("num_speculative_tokens", None),
     )
 
     # Select the customIO config based on the mx flag.
diff --git a/QEfficient/exporter/export_hf_to_cloud_ai_100.py b/QEfficient/exporter/export_hf_to_cloud_ai_100.py
index 706d14107..037040934 100644
--- a/QEfficient/exporter/export_hf_to_cloud_ai_100.py
+++ b/QEfficient/exporter/export_hf_to_cloud_ai_100.py
@@ -213,6 +213,7 @@ def export_kvstyle_transformed_model_to_onnx(
         prompt_len=Constants.PROMPT_LEN,
         ctx_len=seq_len,
         full_batch_size=full_batch_size,
+        num_speculative_tokens=getattr(transformed_model, "num_speculative_tokens", None),
     )
 
     inputs = input_handler.prepare_pytorch_inputs()
diff --git a/QEfficient/transformers/modeling_spd_utils.py b/QEfficient/transformers/modeling_spd_utils.py
new file mode 100644
index 000000000..525d9e01b
--- /dev/null
+++ b/QEfficient/transformers/modeling_spd_utils.py
@@ -0,0 +1,18 @@
+from typing import Optional
+
+import torch
+
+
+def filter_hidden_states(
+    hidden_states: torch.Tensor,
+    position_ids: torch.Tensor,
+    num_speculative_tokens: Optional[int],
+) -> torch.Tensor:
+    """filter hidden states based on whether this is a TLM SpD model"""
+    batch_indices = torch.arange(position_ids.shape[0])
+    if num_speculative_tokens is not None:
+        # all logits need to be computed
+        return hidden_states[batch_indices].squeeze(1)
+    # Cast to INT32 to avoid issue while running in ONNXRT
+    logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True)
+    return hidden_states[batch_indices.view(-1, 1), logit_index]
diff --git a/QEfficient/transformers/models/llama/modeling_llama.py b/QEfficient/transformers/models/llama/modeling_llama.py
index 5f4aa2e55..fe9af32f6 100644
--- a/QEfficient/transformers/models/llama/modeling_llama.py
+++ b/QEfficient/transformers/models/llama/modeling_llama.py
@@ -31,6 +31,7 @@
 )
 
 from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask
+from QEfficient.transformers.modeling_spd_utils import filter_hidden_states
 
 
 class QEffLlamaRotaryEmbedding(LlamaRotaryEmbedding):
@@ -288,8 +289,7 @@ def forward(
         )
 
         # Cast to INT32 to avoid issue while running in ONNXRT
-        logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True)
-        hidden_states = outputs[0][torch.arange(position_ids.shape[0]).view(-1, 1), logit_index]
+        hidden_states = filter_hidden_states(outputs[0], position_ids, getattr(self, "num_speculative_tokens", None))
         if self.config.pretraining_tp > 1:
             lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
             logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py
index 5cd058bea..75cd1be74 100644
--- a/QEfficient/transformers/models/modeling_auto.py
+++ b/QEfficient/transformers/models/modeling_auto.py
@@ -168,7 +168,6 @@ def transform(self, **kwargs):
         """
         if self.is_transformed:
             return
-
         if self.full_batch_size is not None:
             if KVCacheTransform in self._pytorch_transforms:
                 self._pytorch_transforms[self._pytorch_transforms.index(KVCacheTransform)] = CBTransform
@@ -188,6 +187,19 @@ def transform(self, **kwargs):
             if isinstance(self.model.config.quantization_config, QEffGPTQConfig):
                 self._pytorch_transforms.insert(0, GPTQToMatmulNbitsTransform)
 
+        num_speculative_tokens = kwargs.get("num_speculative_tokens", None)
+        is_dlm = kwargs.get("is_dlm", False)
+        assert (
+            not isinstance(num_speculative_tokens, int)
+        ) or not is_dlm, "number of speculative tokens are only to be specified for Target LM"
+        if num_speculative_tokens:
+            assert isinstance(num_speculative_tokens, int) and num_speculative_tokens > 0, (
+                "argument num_speculative_tokens" " should be of type integer and" " be positive if specified"
+            )
+            setattr(self.model, "num_speculative_tokens", num_speculative_tokens)
+        elif is_dlm:
+            setattr(self.model, "is_dlm", True)
+
         for transform in self._pytorch_transforms:
             transform.apply(self.model)
         self.is_transformed = True
@@ -289,6 +301,8 @@ def compile(
             mxfp6=mxfp6,
             mxint8=mxint8,
             full_batch_size=self.full_batch_size,
+            num_speculative_tokens=getattr(self.model, "num_speculative_tokens", None),
+            is_dlm=getattr(self.model, "is_dlm", False),
         )
         self.qpc_path = qpc_dir_path
         return self.qpc_path
diff --git a/QEfficient/utils/generate_inputs.py b/QEfficient/utils/generate_inputs.py
index 252d445e2..8b9183ecd 100644
--- a/QEfficient/utils/generate_inputs.py
+++ b/QEfficient/utils/generate_inputs.py
@@ -12,7 +12,9 @@
 
 
 class InputHandler:
-    def __init__(self, batch_size, tokenizer, config, prompt, prompt_len, ctx_len, full_batch_size):
+    def __init__(
+        self, batch_size, tokenizer, config, prompt, prompt_len, ctx_len, full_batch_size, num_speculative_tokens
+    ):
         """
         Initialization
 
@@ -24,6 +26,7 @@ def __init__(self, batch_size, tokenizer, config, prompt, prompt_len, ctx_len, f
         :prompt_len (int): Prompt length for the model to compile.
         :ctx_len (int): Maximum context length to compile the model.
         :full_batch_size (int): Continuous batching batch size
+        :num_speculative_tokens (Optional[int]): used to determine whether this is a TLM model or not
         """
         # check and fix tokenizer viability
         padding_check_and_fix(tokenizer)
@@ -32,6 +35,7 @@ def __init__(self, batch_size, tokenizer, config, prompt, prompt_len, ctx_len, f
         self.prompt_len = prompt_len
         self.ctx_len = ctx_len
         self.full_batch_size = full_batch_size
+        self.num_speculative_tokens = num_speculative_tokens
         self.n_layer = get_num_layers_from_config(config)
         self.padding_shape = get_padding_shape_from_config(
             config=config, batch_size=full_batch_size if full_batch_size else batch_size, seq_len=ctx_len
@@ -99,8 +103,10 @@ def update_pytorch_inputs(self, inputs, pt_outputs):
         updated_inputs = {}
         if self.full_batch_size:
             batch_index = torch.arange(1).view(-1, 1)
-
-            input_ids = pt_outputs.logits.detach().argmax(2)
+            if self.num_speculative_tokens:
+                input_ids = pt_outputs.logits.detach()[:, -1].argmax(1, keepdim=True)
+            else:
+                input_ids = pt_outputs.logits.detach().argmax(2)
             updated_inputs["input_ids"] = torch.full((self.full_batch_size, 1), self.tokenizer.pad_token_id)
             updated_inputs["input_ids"][batch_index.view(-1)] = input_ids
 
@@ -111,7 +117,12 @@ def update_pytorch_inputs(self, inputs, pt_outputs):
             updated_inputs["batch_index"] = torch.arange(self.full_batch_size).view(-1, 1)
 
         else:
-            updated_inputs["input_ids"] = pt_outputs["logits"].argmax(-1).reshape(-1, 1)
+            if self.num_speculative_tokens:
+                # assume spec decoding logits
+                input_ids = pt_outputs["logits"][:, -1].argmax(-1).reshape(-1, 1)
+            else:
+                input_ids = pt_outputs["logits"].argmax(-1).reshape(-1, 1)
+            pt_outputs["input_ids"] = input_ids
             updated_inputs["position_ids"] = inputs["position_ids"].max(1, keepdim=True).values + 1
             updated_inputs["past_key_values"] = tuple(
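
Below is a minimal usage sketch of the workflow this patch enables, assuming QEFFAutoModelForCausalLM is importable from the QEfficient package; the model cards and the num_speculative_tokens value are illustrative placeholders, and only transform(), export_and_compile(), and the new num_speculative_tokens/is_dlm kwargs come from the patch and its usage notes above:

# Sketch only: model cards and the value 4 are placeholders, not part of this patch.
from QEfficient import QEFFAutoModelForCausalLM  # assumed public import path

# Target LM (verifier): keeps logits for every input position and compiles a
# decode specialization with seq_len = num_speculative_tokens + 1.
tlm = QEFFAutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf")  # placeholder card
tlm.transform(num_speculative_tokens=4)  # must be a positive int per the new assert
tlm.export_and_compile()

# Draft LM: compiled with the extra seq_len=2 specialization enabled by is_dlm=True.
dlm = QEFFAutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")  # placeholder card
dlm.transform(is_dlm=True)
dlm.export_and_compile()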