[WIP] Run inference with CB/non-CB SpD Models (#155)
* adding spd inference script by apoorva

Signed-off-by: eplatero <[email protected]>

* use pytest parametrize configs

Signed-off-by: eplatero <[email protected]>

* rm  function as it was causing some corruption when populating logits

Signed-off-by: eplatero <[email protected]>

* first draft of script getting 100% acceptance rate when using same TLM/DLM

Signed-off-by: eplatero <[email protected]>

* validation with full_batch_size=2 is passing

Signed-off-by: eplatero <[email protected]>

* solved bugs with batch_size>1 and num_spec_tokens>1. 1 final bug remaining

Signed-off-by: eplatero <[email protected]>

* resolved some bugs

Signed-off-by: eplatero <[email protected]>

* rm most debug logs

Signed-off-by: eplatero <[email protected]>

* fix bug when some samples get all accepted and others do not

Signed-off-by: eplatero <[email protected]>

* assert spd output matches vanilla dlm output

Signed-off-by: eplatero <[email protected]>

* linting

Signed-off-by: eplatero <[email protected]>

* added higher spec_len

Signed-off-by: eplatero <[email protected]>

---------

Signed-off-by: eplatero <[email protected]>
Co-authored-by: eplatero <[email protected]>
quic-agokhale and eplatero97 authored Dec 18, 2024
1 parent 26e472e commit 1d7c624
Showing 1 changed file with 340 additions and 0 deletions.
tests/transformers/spd/test_spd_inference.py
@@ -0,0 +1,340 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

from time import perf_counter
from typing import List, Optional

import numpy as np
import pytest
from transformers import AutoTokenizer

from QEfficient import QEFFAutoModelForCausalLM as AutoModelForCausalLM
from QEfficient.generation.cloud_infer import QAICInferenceSession
from QEfficient.utils.constants import Constants
from QEfficient.utils.device_utils import get_available_device_id

configs = [
pytest.param(
Constants.INPUT_STR, # prompt
4, # num_speculative_tokens
32, # prefill_seq_len
128, # ctx_len
1, # prefill_bsz
"TinyLlama/TinyLlama-1.1B-Chat-v1.0", # draft_model_name
"TinyLlama/TinyLlama-1.1B-Chat-v1.0", # target_model_name
1, # full_batch_size
id="CB llama",
),
]
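# Hypothetical sketch (not part of this commit): a non-CB (regular batching) case could be
# parametrized the same way by passing full_batch_size=None, which disables continuous
# batching in the test body below; values mirror the CB config above.
# pytest.param(
#     Constants.INPUT_STR, 4, 32, 128, 1,
#     "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
#     None,  # full_batch_size=None -> non-CB
#     id="non-CB llama",
# ),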


def run_prefill_on_draft_and_target(
tlm_session: QAICInferenceSession,
dlm_session: QAICInferenceSession,
inputs: dict,
prefill_seq_len: int,
slot_idx: int,
):
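# Run chunked prefill for a single prompt on both the target (TLM) and draft (DLM)
# sessions, writing into KV-cache slot `slot_idx`, and return the TLM prefill logits.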
input_len = inputs.input_ids.shape[1]
num_chunks = input_len // prefill_seq_len
cache_index = np.array([[0]], np.int64)
batch_index = np.array([[slot_idx]], np.int64)
inputs["batch_index"] = batch_index

# Run chunked prefill
for i in range(num_chunks):
chunk_inputs = inputs.copy()
chunk_inputs["input_ids"] = inputs["input_ids"][:, cache_index[0, 0] : cache_index[0, 0] + prefill_seq_len]
chunk_inputs["position_ids"] = inputs["position_ids"][
:, cache_index[0, 0] : cache_index[0, 0] + prefill_seq_len
]

tlm_outputs = tlm_session.run(chunk_inputs)
_ = dlm_session.run(chunk_inputs)
cache_index += prefill_seq_len

tlm_logits = tlm_outputs["logits"]
return tlm_logits


def get_padded_input_len(input_len: int, prefill_seq_len: int, ctx_len: int):
"""return padded input length (must be factor of `prefill_seq_len`)
Args:
input_len (int): prompt length
prefill_seq_len (int): prefill sequence length
ctx_len (int): context length
Returns:
input_len_padded (int): padded input length
"""
num_chunks = -(input_len // -prefill_seq_len) # ceil divide without float
input_len_padded = num_chunks * prefill_seq_len # Convert input_len to a multiple of prefill_seq_len
assert (
input_len_padded <= ctx_len
), "input_len rounded to nearest prefill_seq_len multiple should be less than ctx_len"
return input_len_padded


def split_dlm_bonus_token_inputs(dlm_decode_inputs):
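# Split the 2-column DLM decode inputs produced after a full acceptance into two
# single-token input dicts (column-wise split), both sharing the same batch_index.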
bonus_token_inputs = dict()
bonus, regular = np.hsplit(dlm_decode_inputs["input_ids"], 2)
bonus_token_inputs["input_ids"] = bonus
dlm_decode_inputs["input_ids"] = regular
bonus, regular = np.hsplit(dlm_decode_inputs["position_ids"], 2)
bonus_token_inputs["position_ids"] = bonus
dlm_decode_inputs["position_ids"] = regular
bonus_token_inputs["batch_index"] = dlm_decode_inputs["batch_index"]
return bonus_token_inputs, dlm_decode_inputs


@pytest.mark.parametrize(
"prompt, num_speculative_tokens, prefill_seq_len, ctx_len, prefill_bsz, draft_model_name, target_model_name, full_batch_size",
configs,
)
def test_spec_decode_inference(
prompt: List[str],
num_speculative_tokens: int,
prefill_seq_len: int,
ctx_len: int,
prefill_bsz: int,
draft_model_name: str,
target_model_name: str,
full_batch_size: Optional[int],
):
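# End-to-end speculative decoding (SpD) test: prefill both models, let the DLM propose
# num_speculative_tokens greedy tokens per iteration, score them all with a single TLM
# precode call, accept the longest exactly-matching prefix plus one token from the TLM,
# and finally check that the SpD output matches vanilla (DLM-only) greedy generation.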
# get device group
device_group: List[int] = get_available_device_id()
if not device_group:
pytest.skip("No available devices to run model on Cloud AI 100")
# assumes dlm and tlm are compiled to the same prompt-chunk-size, context length and full_batch_size/batch-size
# get vocab size
tokenizer = AutoTokenizer.from_pretrained(target_model_name, padding_side="right")
if tokenizer.pad_token_id is None:
tokenizer.pad_token_id = tokenizer.eos_token_id
vocab_size = len(tokenizer)

# export_and_compile tlm and dlm
continuous_batching = full_batch_size is not None
target_model = AutoModelForCausalLM.from_pretrained(
target_model_name, continuous_batching=continuous_batching, is_tlm=True
)
draft_model = AutoModelForCausalLM.from_pretrained(draft_model_name, continuous_batching=continuous_batching)

num_devices = len(device_group)
target_model_qpc_path: str = target_model.compile(
num_cores=11,
num_devices=num_devices,
prefill_seq_len=prefill_seq_len,
ctx_len=ctx_len,
aic_enable_depth_first=True,
full_batch_size=full_batch_size,
num_speculative_tokens=num_speculative_tokens,
)
draft_model_qpc_path: str = draft_model.compile(
num_cores=5,
prefill_seq_len=prefill_seq_len,
ctx_len=ctx_len,
aic_enable_depth_first=True,
full_batch_size=full_batch_size,
)
# init qaic session
target_model_session = QAICInferenceSession(target_model_qpc_path, device_ids=device_group)
draft_model_session = QAICInferenceSession(draft_model_qpc_path, device_ids=device_group)

# skip inputs/outputs buffers
target_model_session.skip_buffers(set([x for x in target_model_session.input_names if x.startswith("past_")]))
target_model_session.skip_buffers(
set([x for x in target_model_session.output_names if x.endswith("_RetainedState")])
)
draft_model_session.skip_buffers(set([x for x in draft_model_session.input_names if x.startswith("past_")]))
draft_model_session.skip_buffers(set([x for x in draft_model_session.output_names if x.endswith("_RetainedState")]))

is_cb = full_batch_size is not None
if not is_cb:
prompts = prompt * prefill_bsz
decode_batch_size = prefill_bsz
else:
prompts = prompt
decode_batch_size = full_batch_size
# tokenize the prompts
prompts_tokenized: List[dict] = []
for p in prompts:
input_len: int = tokenizer(p, return_tensors="np", padding=True).input_ids.shape[1]
input_len_padded: int = get_padded_input_len(input_len, prefill_seq_len, ctx_len)
p_tok: dict = tokenizer(p, return_tensors="np", padding="max_length", max_length=input_len_padded)
position_ids = np.where(p_tok.pop("attention_mask"), np.arange(input_len_padded), -1)
p_tok["position_ids"] = position_ids
prompts_tokenized.append(p_tok)
# create caches to hold generated ids and input prompt lengths
generated_ids = [[] for i in range(decode_batch_size)]
input_lengths = [0] * decode_batch_size
# run prefill on both draft and target models
dlm_decode_inputs = dict()
dlm_decode_inputs["position_ids"] = np.zeros((decode_batch_size, 1), np.int64)
dlm_decode_inputs["input_ids"] = np.full((decode_batch_size, 1), tokenizer.pad_token_id)
dlm_decode_inputs["batch_index"] = np.reshape(
np.array(np.arange(decode_batch_size), np.int64), (decode_batch_size, 1)
)
# precode inputs for the TLM: column 0 holds the latest committed token, columns 1..num_speculative_tokens hold the DLM proposals
tlm_precode_inputs = dict(
input_ids=np.zeros((decode_batch_size, num_speculative_tokens + 1), dtype=np.int64),
position_ids=np.zeros((decode_batch_size, num_speculative_tokens + 1), dtype=np.int64),
batch_index=np.arange(decode_batch_size, dtype=np.int64).reshape(-1, 1),
)
max_gen_len = [ctx_len] * decode_batch_size
num_logits_to_keep = num_speculative_tokens + 1
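# during precode the TLM returns one set of logits for each of the num_speculative_tokens + 1 input positions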
# setup buffers
tlm_prefill_logits_ph = np.zeros((prefill_bsz, 1, vocab_size), dtype=np.float32)
dlm_prefill_logits_ph = np.zeros((prefill_bsz, 1, vocab_size), dtype=np.float32)
decode_logits_ph = np.zeros((decode_batch_size, 1, vocab_size), dtype=np.float32)
precode_logits_ph = np.zeros((decode_batch_size, num_logits_to_keep, vocab_size), dtype=np.float32)

target_model_session.set_buffers({"logits": tlm_prefill_logits_ph})
draft_model_session.set_buffers({"logits": dlm_prefill_logits_ph})
e2e_start = perf_counter()
ttfts = []
for bi in range(decode_batch_size):
# assumes that prefill queue will always be popped from the front
start = perf_counter()
tlm_logits = run_prefill_on_draft_and_target(
tlm_session=target_model_session,
dlm_session=draft_model_session,
inputs=prompts_tokenized[bi],
prefill_seq_len=prefill_seq_len,
slot_idx=bi,
)
ttft = perf_counter() - start
ttfts.append(ttft)
input_ids = tlm_logits.argmax(2).astype(np.int64)
generated_ids[bi].append(input_ids.item())
dlm_decode_inputs["input_ids"][bi, 0] = input_ids
tlm_precode_inputs["input_ids"][bi, 0] = input_ids.item()
input_len = prompts_tokenized[bi]["position_ids"].max(1).item() + 1
dlm_decode_inputs["position_ids"][bi, 0] = input_len
tlm_precode_inputs["position_ids"][bi] = np.arange(
input_len, input_len + num_speculative_tokens + 1, dtype=np.int64
)
# assumes that prefill queue will always be popped from the front
input_lengths[bi] = input_len
max_gen_len[bi] -= input_lengths[bi]
batch_ttft = perf_counter() - e2e_start

# set decode logits buffers
target_model_session.set_buffers({"logits": precode_logits_ph})
draft_model_session.set_buffers({"logits": decode_logits_ph})
# start decode phase
valid_batch_indices = np.full(decode_batch_size, True, dtype=bool)
it = 0
decode_start = perf_counter()
mean_num_accepted_tokens = 0
all_accept = np.full(decode_batch_size, False, dtype=bool)
while True:
it += 1
# generate proposals from draft model
for k_ in range(num_speculative_tokens):
if all_accept.any():
# when the previous iteration accepted all proposals, run DLM decode one extra time in the first speculative iteration
# as a workaround to avoid incorrect precode with the 3-specialized multi-batch DLM
bonus_token_inputs, dlm_decode_inputs = split_dlm_bonus_token_inputs(dlm_decode_inputs)
_ = draft_model_session.run(bonus_token_inputs)
all_accept[:] = False
dlm_outputs = draft_model_session.run(dlm_decode_inputs)
input_ids = dlm_outputs["logits"].argmax(2)
tlm_precode_inputs["input_ids"][:, k_ + 1] = input_ids.flatten()
dlm_decode_inputs["input_ids"] = input_ids
dlm_decode_inputs["position_ids"][valid_batch_indices] += 1
# run precode on TLM to score the proposed tokens
tlm_outputs = target_model_session.run(tlm_precode_inputs)
target_logits = tlm_outputs["logits"]
# greedy sampling from target model
target_tokens = target_logits.argmax(-1)
# exact matching between draft and target tokens
draft_tokens = tlm_precode_inputs["input_ids"][:, 1:]
matching = draft_tokens == target_tokens[:, :-1] # shape: [decode_batch_size, num_speculative_tokens]
num_tokens_selected = matching.cumprod(axis=1).sum(axis=1) + 1 # shape: [decode_batch_size]
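# e.g., matching = [True, True, False, True] -> cumprod = [1, 1, 0, 0] -> 2 accepted draft
# tokens, +1 for the TLM's own token at the first mismatch (the bonus token when all match)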
all_accept[valid_batch_indices] = num_tokens_selected[valid_batch_indices] == num_speculative_tokens + 1
mean_num_accepted_tokens += num_tokens_selected[valid_batch_indices].mean().item()
# append selected tokens to the generated_ids
tlm_precode_position_ids = tlm_precode_inputs["position_ids"] + num_tokens_selected.reshape(
decode_batch_size, 1
)
for bi, valid in enumerate(valid_batch_indices):
if not valid:
continue
accepted_tokens = num_tokens_selected[bi]
num_tokens_to_append = min(accepted_tokens, max_gen_len[bi] - len(generated_ids[bi]))
generated_ids[bi].extend(target_tokens[bi, :num_tokens_to_append].tolist())
# position_ids > ctx_len-1 produce erroneous TLM logits at every seq_len position
# (e.g., ctx_len=128 -> position_ids=[127,128,129] give erroneous output for each predicted token)
if len(generated_ids[bi]) >= max_gen_len[bi] or (tlm_precode_position_ids[bi] > ctx_len - 1).any():
valid_batch_indices[bi] = False
# check if all generations are done
if not valid_batch_indices.any():
break
# prepare decode inputs for next decode iteration
num_valid_batch_indices = valid_batch_indices.sum().item()
common_input_ids = target_tokens[valid_batch_indices, num_tokens_selected[valid_batch_indices] - 1].reshape(
num_valid_batch_indices, 1
)
common_position_ids = (
tlm_precode_inputs["position_ids"][
valid_batch_indices, num_tokens_selected[valid_batch_indices] - 1
].reshape(num_valid_batch_indices, 1)
+ 1
)
if all_accept.any():
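# when every proposal was accepted, the DLM still has to ingest both the last speculative
# token and the newly sampled token, so the next iteration feeds 2 tokens and replays them
# as two single-token decode calls via split_dlm_bonus_token_inputs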
# all_accept input_ids
input_ids = np.zeros((decode_batch_size, 2), dtype=np.int64)
last_spec_token_id = target_tokens[valid_batch_indices, -2].reshape(-1, 1)
input_ids[valid_batch_indices] = np.concatenate([last_spec_token_id, common_input_ids], axis=1)
dlm_decode_inputs["input_ids"] = input_ids
# all_accept position_ids
position_ids = np.full((decode_batch_size, 2), -1, dtype=np.int64)
last_spec_position_id = tlm_precode_inputs["position_ids"][valid_batch_indices, -1].reshape(-1, 1)
position_ids[valid_batch_indices] = np.concatenate([last_spec_position_id, common_position_ids], axis=1)
dlm_decode_inputs["position_ids"] = position_ids
else:
dlm_decode_inputs["input_ids"][valid_batch_indices] = common_input_ids
dlm_decode_inputs["position_ids"][valid_batch_indices] = common_position_ids
tlm_precode_inputs["input_ids"][valid_batch_indices, 0] = common_input_ids.flatten()
tlm_precode_inputs["position_ids"][valid_batch_indices] += num_tokens_selected[valid_batch_indices].reshape(
num_valid_batch_indices, 1
)
end = perf_counter()
decode_end = end - decode_start
e2e_end = end - e2e_start
mean_ttft = sum(ttfts) / len(ttfts)
generated_tokens_per_prompt = [len(gid) + 1 for gid in generated_ids]
decode_throughput = sum(generated_tokens_per_prompt) / decode_end
e2e_throughput = (sum(generated_tokens_per_prompt) + decode_batch_size) / e2e_end
batch_decode = tokenizer.batch_decode(generated_ids)
mean_num_accepted_tokens /= it
print(f"Avg TLM+DLM TTFT = {mean_ttft}")
print(f"Total TLM+DLM Batch TTFT = {batch_ttft}")
print(f"Decode Throughput = {decode_throughput}")
print(f"E2E Throughput = {e2e_throughput}")
print("Avg number of accepted tokens = ", mean_num_accepted_tokens)
print("Max generation len = ", max_gen_len)
print("Total Generated Tokens per Prompt: = ", generated_tokens_per_prompt)
for prompt, generation in zip(prompts, batch_decode):
print(f"{prompt=} {generation=}")
# validation check
assert mean_num_accepted_tokens == float(
num_speculative_tokens + 1
), f"mean number of accepted tokens is {mean_num_accepted_tokens} but should be {num_speculative_tokens+1}"
del target_model_session
del draft_model_session
generated_ids = np.asarray(generated_ids).flatten()
gen_len = generated_ids.shape[0]
exec_info = draft_model.generate(tokenizer, Constants.INPUT_STR, device_group)
cloud_ai_100_tokens = exec_info.generated_ids[0][
:gen_len
] # Because we always run for single input and single batch size
all_matching = np.array_equal(cloud_ai_100_tokens, generated_ids)
assert all_matching, "Tokens don't match for SpD output and vanilla DLM output."
