Skip to content

Commit

Permalink
Incomplete development of the SpD TLM validation step. Posting this to show a potential solution for this missing unit test.
Browse files Browse the repository at this point in the history

Signed-off-by: eplatero <[email protected]>
  • Loading branch information
eplatero97 committed Dec 10, 2024
1 parent 5f52d98 commit 1741312
Showing 1 changed file with 46 additions and 1 deletion.
47 changes: 46 additions & 1 deletion QEfficient/generation/text_generation_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class PerfMetrics:
decode_perf: float
total_perf: float
total_time: float
is_matching: Optional[bool]


@dataclass
Expand Down Expand Up @@ -836,6 +837,44 @@ def generate_decode_stream(self, decode_inputs, generation_len):
break
yield decode_inputs["input_ids"] # yield the last token

def is_spd_acceptance_rate_fully_matching(self):
    """Check that the SpD target LM (TLM) fully accepts its own generations.

    Replays ``self.generated_ids`` through the decode session in windows of
    ``self._decode_seq_len`` tokens and verifies that greedy argmax over the
    returned logits reproduces the next window of tokens for every sequence
    in the batch (i.e., a 100% self-acceptance rate).

    Returns:
        bool: True when every re-decoded window matches the stored
        generations; False on the first mismatch (also logged as critical).
    """
    seq_len = self._decode_seq_len
    batch_size = self.generated_ids.shape[0]
    # Window position ids for ALL rows, advanced uniformly each iteration.
    # Rows are retired via the boolean `valid` mask instead of being
    # physically deleted, so row indices never drift between the inputs,
    # the session outputs, and `self.generated_ids`. (The previous draft
    # deleted entries from an index list and then reused original-batch
    # indices against already-filtered arrays, corrupting the bookkeeping.)
    # NOTE(review): assumes self.decode_pos_ids has shape (batch_size, 1)
    # and points one window past the last prefill position — confirm.
    position_ids = np.hstack(
        [self.decode_pos_ids - seq_len + offset for offset in range(seq_len)]
    )
    batch_index = np.arange(batch_size).reshape(-1, 1)
    valid = np.ones(batch_size, dtype=bool)
    start = 0
    while True:
        # Advance every row's window; retired rows just stop being fed.
        position_ids = position_ids + seq_len
        window = self.generated_ids[:, start:start + seq_len]
        # Retire rows whose window crosses the context limit or contains EOS,
        # BEFORE running the session, so retired rows are never re-decoded.
        valid &= ~(position_ids > self._ctx_len - 1).any(1)
        valid &= ~(window == self.tokenizer.eos_token_id).any(1)
        if not valid.any():
            break
        # NOTE(review): assumes the session returns logits only for the rows
        # it was fed (batch dim == mask count) — confirm against the runtime.
        outputs = self._session.run(
            {
                "input_ids": window[valid],
                "position_ids": position_ids[valid],
                "batch_index": batch_index[valid],
            }
        )
        # Greedy prediction per window position vs. the originally generated
        # next tokens; any disagreement means acceptance is below 100%.
        predicted = outputs["logits"].argmax(2)
        expected = self.generated_ids[valid, start + 1:start + 1 + seq_len]
        if not (predicted == expected).all():
            logger.critical("TLM model does not have a 100% acceptance rate with itself!")
            return False
        start += seq_len

    return True



class TextGeneration:
def __init__(
Expand Down Expand Up @@ -931,7 +970,13 @@ def _regular_model_execution(
prefill_time, decode_perf, total_perf, total_time = calculate_latency(
total_decode_tokens, loop_start, start, end
)
self._perf_metrics = PerfMetrics(prefill_time, decode_perf, total_perf, total_time)
# Find whether SpD acceptance rate is fully matching (100% acceptance rate).
is_matching = None
if self.is_tlm:
is_matching = self._qaic_model.is_spd_acceptance_rate_fully_matching()
if not is_matching:
logger.warning(f"SpD TLM model acceptance rate with itself is not 100%!")
self._perf_metrics = PerfMetrics(prefill_time, decode_perf, total_perf, total_time, is_matching)
return self._perf_metrics, generated_texts

def _continuous_batching_execution(
Expand Down

0 comments on commit 1741312

Please sign in to comment.