Add script to run SpD inference on CB models
1 parent 51a94f4 · commit b9fc5e9
Showing 1 changed file with 333 additions and 0 deletions.
@@ -0,0 +1,333 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

from typing import List, Optional

import numpy as np
from transformers import AutoTokenizer

from QEfficient import QEFFAutoModelForCausalLM as AutoModelForCausalLM
from QEfficient.generation.cloud_infer import QAICInferenceSession


def run_prefill_on_draft_and_target(
    tlm_session, dlm_session, prompt, prompt_len, ctx_len, prefill_batch_size, decode_batch_size, slot_idx
):
    tlm_decode_start_input = dict()
    dlm_decode_start_input = dict()
    inputs = prompt
    input_len = prompt.input_ids.shape[1]
    num_chunks = -(input_len // -prompt_len)  # ceil divide without float
    input_len = num_chunks * prompt_len  # Convert input_len to a multiple of prompt_len
    assert input_len <= ctx_len, "padded input_len should not exceed ctx_len"
    # the prompt tokens are already padded to input_len (a multiple of prompt_len) by the tokenizer
    # TODO need to store the attention mask and position ids for each batch element so that we can access them
    # at decode time
    inputs["attention_mask"] = np.concatenate(
        [inputs["attention_mask"].astype(bool) for j in range(decode_batch_size)], 0
    )
    inputs["position_ids"] = (np.cumsum(inputs["attention_mask"][0:1], 1) - 1) * inputs["attention_mask"][0:1]

    # mark padded positions with an invalid position id (-1)
    inputs["position_ids"][~inputs["attention_mask"][0:1]] = -1
    cache_index = np.array([[0]], np.int64)
    batch_index = np.array([[slot_idx]], np.int64)
    inputs["batch_index"] = batch_index

    # Run chunked prefill
    for i in range(num_chunks):
        chunk_inputs = inputs.copy()
        chunk_inputs["input_ids"] = inputs["input_ids"][:, cache_index[0, 0] : cache_index[0, 0] + prompt_len]
        chunk_inputs["position_ids"] = inputs["position_ids"][:, cache_index[0, 0] : cache_index[0, 0] + prompt_len]

        chunk_inputs.pop("attention_mask")
        tlm_outputs = tlm_session.run(chunk_inputs)
        dlm_outputs = dlm_session.run(chunk_inputs)
        cache_index += prompt_len

    tlm_logits = tlm_outputs["logits"]
    dlm_logits = dlm_outputs["logits"]

    if len(tlm_logits.shape) == 2:
        tlm_logits = np.expand_dims(tlm_logits, 1)
    if len(dlm_logits.shape) == 2:
        dlm_logits = np.expand_dims(dlm_logits, 1)

    tlm_decode_start_pos_id = inputs["attention_mask"][0:1].sum(1, keepdims=True)
    # the valid tlm prefill logit sits at the last valid position within the final prompt chunk
    tlm_last_valid_index_in_chunk = (tlm_decode_start_pos_id[0, 0] - 1) % prompt_len
    tlm_prefill_logit = tlm_logits[:, tlm_last_valid_index_in_chunk : tlm_last_valid_index_in_chunk + 1]
    tlm_decode_start_input_id = tlm_prefill_logit.argmax(2)
    dlm_decode_start_input_id = dlm_logits.argmax(2)
    dlm_decode_start_pos_id = inputs["attention_mask"][0:1].sum(1, keepdims=True)

    inputs.pop("attention_mask")

    tlm_decode_start_input = {
        "logits": tlm_prefill_logit,
        "input_ids": tlm_decode_start_input_id,
        "position_ids": tlm_decode_start_pos_id,
        "batch_index": batch_index,
        "input_len": tlm_decode_start_pos_id[0, 0],
    }
    dlm_decode_start_input = {
        "logits": dlm_logits,
        "input_ids": dlm_decode_start_input_id,
        "position_ids": dlm_decode_start_pos_id,
        "batch_index": batch_index,
        "input_len": tlm_decode_start_pos_id[0, 0],
    }

    return tlm_decode_start_input, dlm_decode_start_input


def get_padded_input_len(input_len, prompt_len, ctx_len):
    num_chunks = -(input_len // -prompt_len)  # ceil divide without float
    input_len_padded = num_chunks * prompt_len  # Convert input_len to a multiple of prompt_len
    assert input_len_padded <= ctx_len, "input_len rounded to the nearest prompt_len multiple should not exceed ctx_len"
    return input_len_padded
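
# Quick worked example of the padding rule above (hypothetical numbers, not part
# of the committed script): with prompt_len=32 and ctx_len=128, a 40-token prompt
# gives num_chunks = -(40 // -32) = 2, so it is padded up to 2 * 32 = 64 tokens.
assert get_padded_input_len(input_len=40, prompt_len=32, ctx_len=128) == 64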


def populate_inputs(source, dest, index=None):
    for k, v in dest.items():
        if k == "batch_index":
            continue
        if index is None:
            # during decode
            dest[k] = source[k]
        else:
            # during prefill with bs=1
            dest[k][index] = source[k]


def test_spec_decode_inference(
    prompt: List[str],
    device_group: List[int],
    num_speculative_tokens: int,
    prompt_len: int,
    ctx_len: int,
    prefill_bsz: int,
    draft_model_name: str,
    target_model_name: str,
    full_batch_size: Optional[int] = None,
):
    # assumes dlm and tlm are compiled to the same prompt-chunk-size, context length and full_batch_size/batch-size
    # get vocab size
    tokenizer = AutoTokenizer.from_pretrained(target_model_name)
    vocab_size = len(tokenizer)

    # export_and_compile tlm and dlm
    target_model = AutoModelForCausalLM.from_pretrained(
        target_model_name, num_speculative_tokens=num_speculative_tokens
    )
    draft_model = AutoModelForCausalLM.from_pretrained(draft_model_name, is_dlm=True)

    target_model_qpc_path: str = target_model.export_and_compile(
        num_cores=11,
        device_group=device_group,
        batch_size=prefill_bsz,
        prompt_len=prompt_len,
        ctx_len=ctx_len,
        mxfp6=True,
        mxint8=True,
        full_batch_size=full_batch_size,
    )
    draft_model_qpc_path: str = draft_model.export_and_compile(
        num_cores=5,
        device_group=device_group,
        batch_size=prefill_bsz,
        prompt_len=prompt_len,
        ctx_len=ctx_len,
        mxfp6=True,
        mxint8=True,
        full_batch_size=full_batch_size,
    )

    # init qaic session
    target_model_session = QAICInferenceSession(target_model_qpc_path, device_ids=device_group)
    draft_model_session = QAICInferenceSession(draft_model_qpc_path, device_ids=device_group)

    # skip inputs/outputs buffers
    target_model_session.skip_buffers(set([x for x in target_model_session.input_names if x.startswith("past_")]))
    target_model_session.skip_buffers(
        set([x for x in target_model_session.output_names if x.endswith("_RetainedState")])
    )
    draft_model_session.skip_buffers(set([x for x in draft_model_session.input_names if x.startswith("past_")]))
    draft_model_session.skip_buffers(set([x for x in draft_model_session.output_names if x.endswith("_RetainedState")]))

    is_cb = full_batch_size is not None
    if not is_cb:
        prompts = prompt * prefill_bsz
        decode_batch_size = prefill_bsz
    else:
        prompts = prompt
        decode_batch_size = full_batch_size
    # tokenize the prompts
    prompts_tokenized = []
    for p in prompts:
        input_len = tokenizer(p, return_tensors="np", padding=True).input_ids.shape[1]
        input_len_padded = get_padded_input_len(input_len, prompt_len, ctx_len)
        p_tok = tokenizer(p, return_tensors="np", padding="max_length", max_length=input_len_padded)
        prompts_tokenized.append(p_tok)
    generated_ids = [[] for i in range(decode_batch_size)]
    input_lengths = [0] * decode_batch_size
    # run prefill on both draft and target models
    dlm_decode_inputs = dict()
    dlm_decode_inputs["position_ids"] = np.zeros((decode_batch_size, 1), np.int64)
    dlm_decode_inputs["input_ids"] = np.full((decode_batch_size, 1), tokenizer.pad_token_id)
    dlm_decode_inputs["batch_index"] = np.reshape(
        np.array(np.arange(decode_batch_size), np.int64), (decode_batch_size, 1)
    )
    # mock input key "logits" to store the first batch of output logits
    dlm_decode_inputs["logits"] = np.full((decode_batch_size, 1, vocab_size), 0)
    tlm_precode_inputs = dict(dlm_decode_inputs)
    is_prefill = True
    generation_done = False
    max_gen_len = [ctx_len] * decode_batch_size
    all_accept = np.full((decode_batch_size, num_speculative_tokens), False, dtype=bool)
    tlm_prefill_logits_ph = np.zeros((prefill_bsz, prompt_len, vocab_size), dtype=np.float32)
    dlm_prefill_logits_ph = np.zeros((prefill_bsz, 1, vocab_size), dtype=np.float32)
    decode_logits_ph = np.zeros((decode_batch_size, 1, vocab_size), dtype=np.float32)
    precode_logits_ph = np.zeros((decode_batch_size, num_speculative_tokens + 1, vocab_size), dtype=np.float32)

    target_model_session.set_buffers({"logits": tlm_prefill_logits_ph})
    draft_model_session.set_buffers({"logits": dlm_prefill_logits_ph})
    for bi in range(decode_batch_size):
        # assumes that the prefill queue will always be popped from the front
        tlm_prefill_output, dlm_prefill_output = run_prefill_on_draft_and_target(
            tlm_session=target_model_session,
            dlm_session=draft_model_session,
            prompt=prompts_tokenized[bi],
            prompt_len=prompt_len,
            ctx_len=ctx_len,
            prefill_batch_size=prefill_bsz,
            decode_batch_size=decode_batch_size,
            slot_idx=bi,
        )
        # this way, we will directly get the updated full batch input dict to run decode
        populate_inputs(dlm_prefill_output, dlm_decode_inputs, bi)
        populate_inputs(tlm_prefill_output, tlm_precode_inputs, bi)
        input_lengths[bi] = tlm_prefill_output["input_len"]
        max_gen_len[bi] -= input_lengths[bi]

    target_model_session.set_buffers({"logits": precode_logits_ph})
    draft_model_session.set_buffers({"logits": decode_logits_ph})
    while not generation_done:
        # compute the processed context length before each iteration to prepare the position id inputs
        processed_context = [len(generated_ids[j]) + input_lengths[j] for j in range(decode_batch_size)]
        # generate proposals from the draft model
        if is_prefill:
            draft_logits = [dlm_decode_inputs.pop("logits")]
            target_logits = [tlm_precode_inputs.pop("logits")]
        else:
            if np.any(all_accept):
                input_ids = []
                position_ids = []
                for bi in range(decode_batch_size):
                    if all_accept[bi]:
                        # both the last DLM token and the bonus TLM token are passed as input to the DLM
                        input_ids.append([generated_ids[bi][-2], generated_ids[bi][-1]])
                        position_ids.append([processed_context[bi] - 2, processed_context[bi] - 1])
                    else:
                        # only the corrected token from the TLM of the previous iteration, plus the pad_token as a dummy
                        input_ids.append([generated_ids[bi][-1], tokenizer.pad_token_id])
                        position_ids.append([processed_context[bi] - 1, -1])
                dlm_decode_inputs["input_ids"] = np.array(input_ids)
                dlm_decode_inputs["position_ids"] = np.array(position_ids)
            else:
                dlm_decode_inputs["input_ids"] = np.array([gid[-1] for gid in generated_ids], dtype=np.int64).reshape(
                    (decode_batch_size, 1)
                )
                dlm_decode_inputs["position_ids"] = np.array(
                    [(pc - 1) for pc in processed_context], dtype=np.int64
                ).reshape((decode_batch_size, 1))
        # prepare the inputs for the dlm speculation
        # TODO if even one element of the batch has all_accept, we have to use the seqlen=2 specialization,
        # hence the dummy -1 position id for the other sequences.
        # dlm_decode_inputs["position_ids"] = len(generated_ids per batch)
        # dlm_decode_inputs["input_ids"] = (last generated dlm token) + last true token from TLM
        for k_ in range(num_speculative_tokens):
            dlm_outputs = draft_model_session.run(dlm_decode_inputs)
            draft_logits.append(dlm_outputs["logits"])
            dlm_decode_inputs["input_ids"] = dlm_outputs["logits"].argmax(-1)
            dlm_decode_inputs["position_ids"] = dlm_decode_inputs["position_ids"][:, -1:] + 1

        draft_logits = np.array(draft_logits).squeeze(2).transpose((1, 0, 2))
        # greedy sampling from the draft model
        draft_tokens = draft_logits.argmax(-1)

        # construct precode inputs
        tlm_precode_inputs["input_ids"] = draft_tokens
        if not is_prefill:
            last_genid = np.array([gid[-1] for gid in generated_ids], dtype=np.int64).reshape(decode_batch_size, 1)
            tlm_precode_inputs["input_ids"] = np.concatenate((last_genid, tlm_precode_inputs["input_ids"]), axis=1)
            # in the general precode case, the first input token is the last generated TLM token (kv cache backfill)
            tlm_precode_inputs["position_ids"] = np.array(
                [np.arange(start=pc - 1, stop=pc + num_speculative_tokens) for pc in processed_context], dtype=np.int64
            )
        else:
            # for the very first precode, all positions being fed in are new
            tlm_precode_inputs["position_ids"] = np.array(
                [np.arange(start=pc, stop=pc + num_speculative_tokens + 1) for pc in processed_context], dtype=np.int64
            )

        # run precode on the TLM to score the proposed tokens
        tlm_outputs = target_model_session.run(tlm_precode_inputs)
        target_precode_logits = tlm_outputs["logits"]
        if is_prefill:
            # stack the prefill output logit and the precode logits into a single tensor
            target_logits = np.concatenate((target_logits[0], target_precode_logits), axis=1)
        else:
            target_logits = target_precode_logits
        # greedy sampling from the target model
        target_tokens = target_logits.argmax(-1)
        # exact matching between draft and target tokens
        matching = draft_tokens == target_tokens[:, :-1]
        num_tokens_selected = np.argmin(matching, axis=1)
        all_accept = matching[np.arange(decode_batch_size), num_tokens_selected]
        num_tokens_selected = np.where(all_accept, matching.shape[1], num_tokens_selected)

        # append the selected tokens to generated_ids
        for bi in range(decode_batch_size):
            if len(generated_ids[bi]) >= max_gen_len[bi]:
                continue
            num_tokens_to_append = min(num_tokens_selected[bi], max_gen_len[bi] - len(generated_ids[bi]))
            generated_ids[bi] += list(draft_tokens[bi, :num_tokens_to_append])
        # append the bonus/corrected token where applicable
        for bi in range(decode_batch_size):
            if len(generated_ids[bi]) >= max_gen_len[bi]:
                continue
            if all_accept[bi]:
                # bonus token
                generated_ids[bi].append(target_tokens[bi, -1])
            else:
                # corrected token
                generated_ids[bi].append(target_tokens[bi, num_tokens_selected[bi]])
        generation_done = True
        for bi in range(decode_batch_size):
            if len(generated_ids[bi]) < max_gen_len[bi]:
                generation_done = False
        is_prefill = False
        draft_logits = []
        target_logits = []
    print("max generation len = ", max_gen_len)
    print("actual generation len = ", [len(gid) for gid in generated_ids])
    print(tokenizer.batch_decode(generated_ids))


test_spec_decode_inference(
    ["My name is", "Hello", "Hi", "My name is"],
    [0],
    5,
    32,
    128,
    1,
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    4,
)
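For readability, the same invocation can also be written with keyword arguments; this is only an equivalent sketch using the parameter names from the test_spec_decode_inference signature above, not an additional call in the committed file:

test_spec_decode_inference(
    prompt=["My name is", "Hello", "Hi", "My name is"],
    device_group=[0],
    num_speculative_tokens=5,
    prompt_len=32,
    ctx_len=128,
    prefill_bsz=1,
    draft_model_name="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    target_model_name="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    full_batch_size=4,
)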