[eplatero] Add support for exporting and compiling models for SpD
(https://jira-dc.qualcomm.com/jira/browse/CLOUDPERF-43)
This change has been validated and posted on behalf of Erick Platero.

It adds support for generating a Target LM that runs as a verifier model by
outputting the logits for every position of the input sequence instead of
only the last position.

It also allows compiling the Target and Draft LMs with specializations
that support SpD.

Usage:

TLM:
tlm = QEFFAutoModelForCausalLM.from_pretrained(<tlm-model-card>)
tlm.transform(num_speculative_tokens=<k>)
tlm.export_and_compile(<compiler-args>)

DLM:
dlm = QEFFAutoModelForCausalLM.from_pretrained(<dlm-model-card>)
dlm.transform(is_dlm=True)
dlm.export_and_compile(<compiler-args>)
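
For reference, a sketch of the decode-time specializations these options translate into (illustrative values, assuming num_speculative_tokens=3, batch_size=1, ctx_len=2048; see compile_helper.py below):

# TLM decode specialization: seq_len = num_speculative_tokens + 1
tlm_decode = {"batch_size": "1", "seq_len": "4", "ctx_len": "2048"}
# DLM: an extra seq_len=2 specialization is appended next to the usual
# prefill and seq_len=1 decode specializations
dlm_extra = {"batch_size": "1", "seq_len": "2", "ctx_len": "2048"}
print(tlm_decode, dlm_extra)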
Apoorva Gokhale committed Sep 23, 2024
1 parent afb4645 commit dbc1712
Showing 6 changed files with 88 additions and 19 deletions.
49 changes: 37 additions & 12 deletions QEfficient/compile/compile_helper.py
@@ -15,24 +15,47 @@


def create_and_dump_specializations(
batch_size: int, prompt_len: int, ctx_len: int, path: str, full_batch_size: Optional[int] = None
batch_size: int,
prompt_len: int,
ctx_len: int,
path: str,
is_dlm: bool,
full_batch_size: Optional[int] = None,
num_speculative_tokens: Optional[int] = None,
):
# Create specialization file.
specializations = {
"specializations": [
{
"batch_size": str(batch_size),
"seq_len": str(prompt_len),
"ctx_len": str(ctx_len),
},
{"batch_size": str(batch_size), "seq_len": "1", "ctx_len": str(ctx_len)},
]
}
# Create specialization cfgs
prefill_specialization = {"batch_size": str(batch_size), "seq_len": str(prompt_len), "ctx_len": str(ctx_len)}
if num_speculative_tokens is None:
decode_specialization = {
"batch_size": str(batch_size),
"seq_len": "1",
"ctx_len": str(ctx_len),
}
else:
decode_specialization = {
"batch_size": str(batch_size),
"seq_len": str(num_speculative_tokens + 1),
"ctx_len": str(ctx_len),
}
specialization_cfgs = [prefill_specialization, decode_specialization]
if is_dlm:
dlm_specialization = {
"batch_size": str(batch_size),
"seq_len": "2",
"ctx_len": str(ctx_len),
}
specialization_cfgs.append(dlm_specialization)

specializations = dict(specializations=specialization_cfgs)

# If continuous batching is enabled by providing full_batch_size, we need to add FBS to the specialization file and update the batch size of the decode part to FBS
if full_batch_size is not None:
specializations["specializations"][0]["full_batch_size"] = str(full_batch_size)
specializations["specializations"][1]["full_batch_size"] = str(full_batch_size)
specializations["specializations"][1]["batch_size"] = str(full_batch_size)
if len(specializations["specializations"]) == 3:
specializations["specializations"][2]["batch_size"] = str(full_batch_size)
specializations["specializations"][2]["full_batch_size"] = str(full_batch_size)

# Dump
with open(path, "w") as file:
@@ -158,6 +181,8 @@ def compile(
ctx_len=ctx_len,
path=specialization_json_path,
full_batch_size=full_batch_size,
is_dlm=kwargs.get("is_dlm", None),
num_speculative_tokens=kwargs.get("num_speculative_tokens", None),
)

# Select the customIO config based on the mx flag.
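
A minimal sketch of exercising the updated helper directly for a hypothetical TLM with num_speculative_tokens=3 and continuous batching enabled; file path and shape values are illustrative, and it assumes the truncated dump step serializes the dict to JSON at `path`:

import json

from QEfficient.compile.compile_helper import create_and_dump_specializations

create_and_dump_specializations(
    batch_size=1,
    prompt_len=32,
    ctx_len=128,
    path="specializations.json",
    is_dlm=False,
    full_batch_size=4,
    num_speculative_tokens=3,
)
with open("specializations.json") as f:
    print(json.dumps(json.load(f), indent=2))
# Expected: a prefill entry (seq_len=32) and a decode entry with
# seq_len=4 (num_speculative_tokens + 1); both carry full_batch_size=4
# and the decode entry's batch_size is rewritten to 4.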
1 change: 1 addition & 0 deletions QEfficient/exporter/export_hf_to_cloud_ai_100.py
@@ -213,6 +213,7 @@ def export_kvstyle_transformed_model_to_onnx(
prompt_len=Constants.PROMPT_LEN,
ctx_len=seq_len,
full_batch_size=full_batch_size,
num_speculative_tokens=getattr(transformed_model, "num_speculative_tokens", None),
)

inputs = input_handler.prepare_pytorch_inputs()
18 changes: 18 additions & 0 deletions QEfficient/transformers/modeling_spd_utils.py
@@ -0,0 +1,18 @@
from typing import Optional

import torch


def filter_hidden_states(
hidden_states: torch.Tensor,
position_ids: torch.Tensor,
num_speculative_tokens: Optional[int],
) -> torch.Tensor:
"""filter hidden states based on whether this is a TLM SpD model"""
batch_indices = torch.arange(position_ids.shape[0])
if num_speculative_tokens is not None:
# all logits need to be computed
return hidden_states[batch_indices].squeeze(1)
# Cast to INT32 to avoid issue while running in ONNXRT
logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True)
return hidden_states[batch_indices.view(-1, 1), logit_index]
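
A toy shape check of the new helper, assuming random hidden states of shape (batch=2, seq_len=5, hidden=8):

import torch

from QEfficient.transformers.modeling_spd_utils import filter_hidden_states

hidden_states = torch.randn(2, 5, 8)
# argmax over position_ids picks the highest (last real) position per sequence
position_ids = torch.tensor([[0, 1, 2, 3, 4], [0, 1, 2, 0, 0]])

# TLM SpD path: keep every position so the verifier can score all drafted tokens
print(filter_hidden_states(hidden_states, position_ids, num_speculative_tokens=3).shape)     # (2, 5, 8)

# Default path: keep only the hidden state at each sequence's last position
print(filter_hidden_states(hidden_states, position_ids, num_speculative_tokens=None).shape)  # (2, 1, 8)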
4 changes: 2 additions & 2 deletions QEfficient/transformers/models/llama/modeling_llama.py
@@ -31,6 +31,7 @@
)

from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask
from QEfficient.transformers.modeling_spd_utils import filter_hidden_states


class QEffLlamaRotaryEmbedding(LlamaRotaryEmbedding):
@@ -288,8 +289,7 @@ def forward(
)

# Cast to INT32 to avoid issue while running in ONNXRT
logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True)
hidden_states = outputs[0][torch.arange(position_ids.shape[0]).view(-1, 1), logit_index]
hidden_states = filter_hidden_states(outputs[0], position_ids, getattr(self, "num_speculative_tokens", None))
if self.config.pretraining_tp > 1:
lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
16 changes: 15 additions & 1 deletion QEfficient/transformers/models/modeling_auto.py
@@ -168,7 +168,6 @@ def transform(self, **kwargs):
"""
if self.is_transformed:
return

if self.full_batch_size is not None:
if KVCacheTransform in self._pytorch_transforms:
self._pytorch_transforms[self._pytorch_transforms.index(KVCacheTransform)] = CBTransform
@@ -188,6 +187,19 @@
if isinstance(self.model.config.quantization_config, QEffGPTQConfig):
self._pytorch_transforms.insert(0, GPTQToMatmulNbitsTransform)

num_speculative_tokens = kwargs.get("num_speculative_tokens", None)
is_dlm = kwargs.get("is_dlm", False)
assert (
not isinstance(num_speculative_tokens, int)
) or not is_dlm, "num_speculative_tokens should only be specified for the Target LM"
if num_speculative_tokens:
assert isinstance(num_speculative_tokens, int) and num_speculative_tokens > 0, (
"argument num_speculative_tokens should be a positive integer if specified"
)
setattr(self.model, "num_speculative_tokens", num_speculative_tokens)
elif is_dlm:
setattr(self.model, "is_dlm", True)

for transform in self._pytorch_transforms:
transform.apply(self.model)
self.is_transformed = True
@@ -289,6 +301,8 @@ def compile(
mxfp6=mxfp6,
mxint8=mxint8,
full_batch_size=self.full_batch_size,
num_speculative_tokens=getattr(self.model, "num_speculative_tokens", None),
is_dlm=getattr(self.model, "is_dlm", False),
)
self.qpc_path = qpc_dir_path
return self.qpc_path
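
A sketch of the intended transform usage, assuming the class is exported at the package root and the model-card strings are placeholders:

from QEfficient import QEFFAutoModelForCausalLM

# Target LM: emit logits for every position so drafted tokens can be verified
tlm = QEFFAutoModelForCausalLM.from_pretrained("<tlm-model-card>")
tlm.transform(num_speculative_tokens=3)

# Draft LM: flags the model so compile adds the seq_len=2 specialization
dlm = QEFFAutoModelForCausalLM.from_pretrained("<dlm-model-card>")
dlm.transform(is_dlm=True)

# Specifying both flags on one model trips the first assertion above
fresh = QEFFAutoModelForCausalLM.from_pretrained("<dlm-model-card>")
try:
    fresh.transform(num_speculative_tokens=3, is_dlm=True)
except AssertionError as err:
    print(err)  # num_speculative_tokens should only be specified for the Target LM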
19 changes: 15 additions & 4 deletions QEfficient/utils/generate_inputs.py
@@ -12,7 +12,9 @@


class InputHandler:
def __init__(self, batch_size, tokenizer, config, prompt, prompt_len, ctx_len, full_batch_size):
def __init__(
self, batch_size, tokenizer, config, prompt, prompt_len, ctx_len, full_batch_size, num_speculative_tokens
):
"""
Initialization
@@ -24,6 +26,7 @@ def __init__(self, batch_size, tokenizer, config, prompt, prompt_len, ctx_len, f
:prompt_len (int): Prompt length for the model to compile.
:ctx_len (int): Maximum context length to compile the model.
:full_batch_size (int): Continuous batching batch size
:num_speculative_tokens (Optional[int]): Used to determine whether this is a TLM (SpD verifier) model.
"""
# check and fix tokenizer viability
padding_check_and_fix(tokenizer)
@@ -32,6 +35,7 @@ def __init__(self, batch_size, tokenizer, config, prompt, prompt_len, ctx_len, f
self.prompt_len = prompt_len
self.ctx_len = ctx_len
self.full_batch_size = full_batch_size
self.num_speculative_tokens = num_speculative_tokens
self.n_layer = get_num_layers_from_config(config)
self.padding_shape = get_padding_shape_from_config(
config=config, batch_size=full_batch_size if full_batch_size else batch_size, seq_len=ctx_len
@@ -99,8 +103,10 @@ def update_pytorch_inputs(self, inputs, pt_outputs):
updated_inputs = {}
if self.full_batch_size:
batch_index = torch.arange(1).view(-1, 1)

input_ids = pt_outputs.logits.detach().argmax(2)
if self.num_speculative_tokens:
input_ids = pt_outputs.logits.detach()[:, -1].argmax(1, keepdim=True)
else:
input_ids = pt_outputs.logits.detach().argmax(2)
updated_inputs["input_ids"] = torch.full((self.full_batch_size, 1), self.tokenizer.pad_token_id)
updated_inputs["input_ids"][batch_index.view(-1)] = input_ids

@@ -111,7 +117,12 @@ def update_pytorch_inputs(self, inputs, pt_outputs):
updated_inputs["batch_index"] = torch.arange(self.full_batch_size).view(-1, 1)

else:
updated_inputs["input_ids"] = pt_outputs["logits"].argmax(-1).reshape(-1, 1)
if self.num_speculative_tokens:
# TLM emits logits for every position; only the last position seeds the next token
input_ids = pt_outputs["logits"][:, -1].argmax(-1).reshape(-1, 1)
else:
input_ids = pt_outputs["logits"].argmax(-1).reshape(-1, 1)
pt_outputs["input_ids"] = input_ids
updated_inputs["position_ids"] = inputs["position_ids"].max(1, keepdim=True).values + 1

updated_inputs["past_key_values"] = tuple(
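
A tensor-level sketch of the two decode-input paths above, using toy logits (batch=1, vocab=32):

import torch

# TLM/SpD case: the model emits logits for every position; only the last
# position's argmax seeds the next decode step
spd_logits = torch.randn(1, 4, 32)
spd_next = spd_logits[:, -1].argmax(1, keepdim=True)     # shape (1, 1)

# Default case: the model already emits a single position
default_logits = torch.randn(1, 1, 32)
default_next = default_logits.argmax(-1).reshape(-1, 1)  # shape (1, 1)

print(spd_next.shape, default_next.shape)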
