diff --git a/QEfficient/utils/generate_inputs.py b/QEfficient/utils/generate_inputs.py
index d1f3dfc53..8a194060e 100644
--- a/QEfficient/utils/generate_inputs.py
+++ b/QEfficient/utils/generate_inputs.py
@@ -15,7 +15,15 @@
 
 class InputHandler:
     def __init__(
-        self, batch_size, tokenizer, config, prompt, prompt_len, ctx_len, full_batch_size, num_logits_to_keep: Optional[int]
+        self,
+        batch_size,
+        tokenizer,
+        config,
+        prompt,
+        prompt_len,
+        ctx_len,
+        full_batch_size,
+        num_logits_to_keep: Optional[int],
     ):
         """
         Initialization
@@ -28,8 +36,8 @@ def __init__(
             :prompt_len (int): Prompt length for the model to compile.
             :ctx_len (int): Maximum context length to compile the model.
             :full_batch_size (int): Continuous batching batch size
-            :num_logits_to_keep (Optional[int]): 
-                Calculate logits for the last valid `num_logits_to_keep` tokens. 
+            :num_logits_to_keep (Optional[int]):
+                Calculate logits for the last valid `num_logits_to_keep` tokens.
                 Only last token logits are needed for generation, and calculating them only for that token can
                 save memory, which becomes pretty significant for long sequences or large vocabulary size.
         """
@@ -116,12 +124,14 @@ def update_pytorch_inputs(self, inputs, pt_outputs):
         if self.full_batch_size:
             # Create CB inputs (make 1 batch index have proper inputs for decode pass)
             batch_index = torch.arange(1).view(-1, 1)
-            batch_idx_input_ids = pt_outputs.logits.detach().argmax(2) # shape: [batch_size, num_logits_to_keep]
+            batch_idx_input_ids = pt_outputs.logits.detach().argmax(2)  # shape: [batch_size, num_logits_to_keep]
             input_ids = torch.full((self.full_batch_size, decode_len), self.tokenizer.pad_token_id)
             input_ids[batch_index.view(-1)] = batch_idx_input_ids
 
             position_ids = torch.full((self.full_batch_size, decode_len), 0)
-            batch_idx_position_ids = torch.arange(decode_len).view(1,-1) + (inputs["position_ids"].max(1, keepdim=True).values + 1)
+            batch_idx_position_ids = torch.arange(decode_len).view(1, -1) + (
+                inputs["position_ids"].max(1, keepdim=True).values + 1
+            )
             position_ids[batch_index.view(-1)] = batch_idx_position_ids
 
             updated_inputs["input_ids"] = input_ids
@@ -130,11 +140,13 @@ def update_pytorch_inputs(self, inputs, pt_outputs):
 
         else:
             if self.num_logits_to_keep is not None:
-                input_ids = pt_outputs["logits"].argmax(-1) # shape: [batch_size, num_logits_to_keep]
+                input_ids = pt_outputs["logits"].argmax(-1)  # shape: [batch_size, num_logits_to_keep]
                 batch_size = input_ids.size(0)
-                position_ids = torch.arange(self.num_logits_to_keep).view(1, self.num_logits_to_keep).repeat(batch_size, 1)
+                position_ids = (
+                    torch.arange(self.num_logits_to_keep).view(1, self.num_logits_to_keep).repeat(batch_size, 1)
+                )
             else:
-                input_ids = pt_outputs["logits"].argmax(-1).reshape(-1, 1) # shape: [batch_size, 1]
+                input_ids = pt_outputs["logits"].argmax(-1).reshape(-1, 1)  # shape: [batch_size, 1]
                 position_ids = inputs["position_ids"].max(1, keepdim=True).values + 1
             updated_inputs["input_ids"] = input_ids
             updated_inputs["position_ids"] = position_ids
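
For context on what the reformatted `num_logits_to_keep` branch of `update_pytorch_inputs` computes, the snippet below is a minimal sketch, not part of the patch: it uses a toy logits tensor in place of real model outputs, and the batch size, kept-logits count, and vocabulary size are made-up values.

import torch

# Illustration only: stand-in for pt_outputs["logits"] with shape
# [batch_size, num_logits_to_keep, vocab_size] (all sizes are arbitrary here).
batch_size, num_logits_to_keep, vocab_size = 2, 3, 8
logits = torch.randn(batch_size, num_logits_to_keep, vocab_size)

# Greedy next tokens, one per kept logit -> shape [batch_size, num_logits_to_keep]
input_ids = logits.argmax(-1)

# One position id per kept logit, repeated across the batch, matching the diff
position_ids = torch.arange(num_logits_to_keep).view(1, num_logits_to_keep).repeat(batch_size, 1)

print(input_ids.shape, position_ids.shape)  # torch.Size([2, 3]) torch.Size([2, 3])

When `num_logits_to_keep` is unset, the same branch instead keeps only the single last-token argmax (`reshape(-1, 1)`) and advances the position ids by one past the previous maximum, as shown in the diff.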