diff --git a/QEfficient/generation/text_generation_inference.py b/QEfficient/generation/text_generation_inference.py
index f1fe5565..3eecd7b3 100755
--- a/QEfficient/generation/text_generation_inference.py
+++ b/QEfficient/generation/text_generation_inference.py
@@ -37,6 +37,7 @@ class PerfMetrics:
     decode_perf: float
     total_perf: float
     total_time: float
+    is_matching: Optional[bool]
 
 
 @dataclass
@@ -836,6 +837,44 @@ def generate_decode_stream(self, decode_inputs, generation_len):
                 break
         yield decode_inputs["input_ids"]  # yield the last token
 
+    def is_spd_acceptance_rate_fully_matching(self):
+        batch_size = self.generated_ids.shape[0]
+        valid_batch_indices = list(range(batch_size))
+        inputs = dict(
+            input_ids=np.zeros((batch_size, self._decode_seq_len)),
+            position_ids=np.hstack([self.decode_pos_ids - self._decode_seq_len + i for i in range(self._decode_seq_len)]),
+            batch_index=np.arange(batch_size).reshape(-1, 1),
+        )
+        i = 0
+        while True:
+            # Prepare inputs.
+            inputs["input_ids"] = self.generated_ids[valid_batch_indices, i : i + self._decode_seq_len]
+            inputs["position_ids"] = inputs["position_ids"][valid_batch_indices] + self._decode_seq_len
+            # Drop entries from `valid_batch_indices` that exceed the ctx_len limit or contain an EOS token.
+            passed_max_limit = (inputs["position_ids"] > self._ctx_len - 1).any(1)
+            passed_indices = np.where(passed_max_limit)[0].tolist()
+            contains_eos = (inputs["input_ids"] == self.tokenizer.eos_token_id).any(1)
+            eos_indices = np.where(contains_eos)[0].tolist()
+            rm_indices = np.unique(np.concatenate((passed_indices, eos_indices))).tolist()
+            if rm_indices:
+                # Delete from the end so earlier positions remain valid.
+                for rm_index in sorted(rm_indices, reverse=True):
+                    del valid_batch_indices[rm_index]
+                if not valid_batch_indices:
+                    break
+            # Get predictions and perform the validation step.
+            outputs = self._session.run(inputs)
+            input_ids = outputs["logits"][valid_batch_indices].argmax(2)  # shape: [len(valid_batch_indices), self._decode_seq_len]
+            matching = input_ids == self.generated_ids[valid_batch_indices, i + 1 : i + 1 + self._decode_seq_len]
+            if not matching.all():
+                logger.critical("TLM model does not have a 100% acceptance rate with itself!")
+                return False
+            i += self._decode_seq_len
+
+        return True
+
 
 class TextGeneration:
     def __init__(
@@ -931,7 +970,13 @@ def _regular_model_execution(
         prefill_time, decode_perf, total_perf, total_time = calculate_latency(
             total_decode_tokens, loop_start, start, end
         )
-        self._perf_metrics = PerfMetrics(prefill_time, decode_perf, total_perf, total_time)
+        # Check whether the SpD acceptance rate is fully matching (100% acceptance rate).
+        is_matching = None
+        if self.is_tlm:
+            is_matching = self._qaic_model.is_spd_acceptance_rate_fully_matching()
+            if not is_matching:
+                logger.warning("SpD TLM model acceptance rate with itself is not 100%!")
+        self._perf_metrics = PerfMetrics(prefill_time, decode_perf, total_perf, total_time, is_matching)
         return self._perf_metrics, generated_texts
 
     def _continuous_batching_execution(
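
For context, below is a minimal, self-contained sketch (not code from this patch) of the self-consistency check that is_spd_acceptance_rate_fully_matching performs: replay the already-generated tokens through the target model in windows of the speculative decode length and verify that greedy argmax reproduces the tokens that were actually generated, i.e. a 100% acceptance rate of the TLM with itself. The toy model, VOCAB size, and window length are illustrative assumptions, not values from the repository.

import numpy as np

VOCAB = 16  # illustrative vocabulary size


def toy_tlm_logits(window_ids: np.ndarray) -> np.ndarray:
    # Stand-in for the TLM session: returns logits whose argmax is a fixed
    # next-token function (next_token = (token + 1) % VOCAB).
    batch, seq = window_ids.shape
    logits = np.zeros((batch, seq, VOCAB))
    next_ids = (window_ids + 1) % VOCAB
    logits[np.arange(batch)[:, None], np.arange(seq)[None, :], next_ids] = 1.0
    return logits


def is_fully_matching(generated_ids: np.ndarray, window: int) -> bool:
    # Replay the generated sequence window-by-window and check that greedy
    # predictions reproduce the tokens that follow each window position.
    _, total = generated_ids.shape
    for i in range(0, total - window, window):
        window_ids = generated_ids[:, i : i + window]
        predicted = toy_tlm_logits(window_ids).argmax(axis=2)
        expected = generated_ids[:, i + 1 : i + 1 + window]
        if not np.array_equal(predicted, expected):
            return False
    return True


if __name__ == "__main__":
    good = np.arange(12).reshape(1, 12) % VOCAB  # sequence the toy model would produce
    bad = good.copy()
    bad[0, 7] = 0  # token the toy model would never emit at this position
    print(is_fully_matching(good, window=4))  # True
    print(is_fully_matching(bad, window=4))   # False

Replaying from the already-materialized generated ids avoids re-running the draft model; any mismatch means the TLM would not accept its own greedy tokens, which the new PerfMetrics.is_matching field surfaces to callers.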