Commit
lint fix
eplatero97 committed Oct 24, 2024
1 parent fa058b7 commit 32ce801
Showing 1 changed file with 20 additions and 8 deletions.
28 changes: 20 additions & 8 deletions QEfficient/utils/generate_inputs.py
@@ -15,7 +15,15 @@
 
 class InputHandler:
     def __init__(
-        self, batch_size, tokenizer, config, prompt, prompt_len, ctx_len, full_batch_size, num_logits_to_keep: Optional[int]
+        self,
+        batch_size,
+        tokenizer,
+        config,
+        prompt,
+        prompt_len,
+        ctx_len,
+        full_batch_size,
+        num_logits_to_keep: Optional[int],
     ):
         """
         Initialization
@@ -28,8 +36,8 @@ def __init__(
         :prompt_len (int): Prompt length for the model to compile.
         :ctx_len (int): Maximum context length to compile the model.
         :full_batch_size (int): Continuous batching batch size
-        :num_logits_to_keep (Optional[int]):
-            Calculate logits for the last valid `num_logits_to_keep` tokens.
+        :num_logits_to_keep (Optional[int]):
+            Calculate logits for the last valid `num_logits_to_keep` tokens.
         Only last token logits are needed for generation, and calculating them only for that
         token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
         """
@@ -116,12 +124,14 @@ def update_pytorch_inputs(self, inputs, pt_outputs):
         if self.full_batch_size:
             # Create CB inputs (make 1 batch index have proper inputs for decode pass)
             batch_index = torch.arange(1).view(-1, 1)
-            batch_idx_input_ids = pt_outputs.logits.detach().argmax(2) # shape: [batch_size, num_logits_to_keep]
+            batch_idx_input_ids = pt_outputs.logits.detach().argmax(2)  # shape: [batch_size, num_logits_to_keep]
             input_ids = torch.full((self.full_batch_size, decode_len), self.tokenizer.pad_token_id)
             input_ids[batch_index.view(-1)] = batch_idx_input_ids
 
             position_ids = torch.full((self.full_batch_size, decode_len), 0)
-            batch_idx_position_ids = torch.arange(decode_len).view(1,-1) + (inputs["position_ids"].max(1, keepdim=True).values + 1)
+            batch_idx_position_ids = torch.arange(decode_len).view(1, -1) + (
+                inputs["position_ids"].max(1, keepdim=True).values + 1
+            )
             position_ids[batch_index.view(-1)] = batch_idx_position_ids
 
             updated_inputs["input_ids"] = input_ids
@@ -130,11 +140,13 @@ def update_pytorch_inputs(self, inputs, pt_outputs):
 
         else:
             if self.num_logits_to_keep is not None:
-                input_ids = pt_outputs["logits"].argmax(-1) # shape: [batch_size, num_logits_to_keep]
+                input_ids = pt_outputs["logits"].argmax(-1)  # shape: [batch_size, num_logits_to_keep]
                 batch_size = input_ids.size(0)
-                position_ids = torch.arange(self.num_logits_to_keep).view(1, self.num_logits_to_keep).repeat(batch_size, 1)
+                position_ids = (
+                    torch.arange(self.num_logits_to_keep).view(1, self.num_logits_to_keep).repeat(batch_size, 1)
+                )
             else:
-                input_ids = pt_outputs["logits"].argmax(-1).reshape(-1, 1) # shape: [batch_size, 1]
+                input_ids = pt_outputs["logits"].argmax(-1).reshape(-1, 1)  # shape: [batch_size, 1]
                 position_ids = inputs["position_ids"].max(1, keepdim=True).values + 1
             updated_inputs["input_ids"] = input_ids
             updated_inputs["position_ids"] = position_ids
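For the non-continuous-batching `else` branch, the `max(1, keepdim=True).values + 1` expression simply resumes counting after each sequence's last valid prompt position. A small illustrative check (the position layout is made up; padded slots hold 0 here):

import torch

# Two right-padded prompts; the shorter one has a 0 in its pad slot.
position_ids = torch.tensor([[0, 1, 2, 3],
                             [0, 1, 2, 0]])

next_position_ids = position_ids.max(1, keepdim=True).values + 1
print(next_position_ids)  # tensor([[4], [3]]) -- one decode slot per sequence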
