
Commit fa058b7
added back mxint8 export and compilation
Signed-off-by: eplatero <[email protected]>
eplatero97 committed Oct 21, 2024
1 parent abd04e4 commit fa058b7
Showing 2 changed files with 20 additions and 18 deletions.
QEfficient/utils/generate_inputs.py: 6 changes (4 additions, 2 deletions)
@@ -116,12 +116,14 @@ def update_pytorch_inputs(self, inputs, pt_outputs):
     if self.full_batch_size:
         # Create CB inputs (make 1 batch index have proper inputs for decode pass)
         batch_index = torch.arange(1).view(-1, 1)
-        batch_idx_input_ids = pt_outputs.logits.detach().argmax(2)
+        batch_idx_input_ids = pt_outputs.logits.detach().argmax(2)  # shape: [batch_size, num_logits_to_keep]
         input_ids = torch.full((self.full_batch_size, decode_len), self.tokenizer.pad_token_id)
         input_ids[batch_index.view(-1)] = batch_idx_input_ids
+
         position_ids = torch.full((self.full_batch_size, decode_len), 0)
         batch_idx_position_ids = torch.arange(decode_len).view(1,-1) + (inputs["position_ids"].max(1, keepdim=True).values + 1)
         position_ids[batch_index.view(-1)] = batch_idx_position_ids
+
         updated_inputs["input_ids"] = input_ids
         updated_inputs["position_ids"] = position_ids
         updated_inputs["batch_index"] = torch.arange(self.full_batch_size).view(-1, 1)
@@ -132,7 +134,7 @@ def update_pytorch_inputs(self, inputs, pt_outputs):
         batch_size = input_ids.size(0)
         position_ids = torch.arange(self.num_logits_to_keep).view(1, self.num_logits_to_keep).repeat(batch_size, 1)
     else:
-        input_ids = pt_outputs["logits"].argmax(-1).reshape(-1, 1)
+        input_ids = pt_outputs["logits"].argmax(-1).reshape(-1, 1)  # shape: [batch_size, 1]
         position_ids = inputs["position_ids"].max(1, keepdim=True).values + 1
     updated_inputs["input_ids"] = input_ids
     updated_inputs["position_ids"] = position_ids
tests/spd/test_tlm_dlm_export_and_compile.py: 32 changes (16 additions, 16 deletions)
@@ -16,23 +16,23 @@

 configs = [
     pytest.param(
-        [0],  # device_group
-        2,  # num_speculative_tokens
-        32,  # prompt_len
-        128,  # ctx_len
-        1,  # prefill_bsz
-        8,  # full_batch_size
-        "JackFram/llama-68m",  # model_name
+        [0],  # device_group
+        2,  # num_speculative_tokens
+        32,  # prompt_len
+        128,  # ctx_len
+        1,  # prefill_bsz
+        8,  # full_batch_size
+        "JackFram/llama-68m",  # model_name
         id="CB llama",
     ),
     pytest.param(
-        [0],  # device_group
-        2,  # num_speculative_tokens
-        32,  # prompt_len
-        128,  # ctx_len
-        1,  # prefill_bsz
-        None,  # full_batch_size
-        "JackFram/llama-68m",  # model_name
+        [0],  # device_group
+        2,  # num_speculative_tokens
+        32,  # prompt_len
+        128,  # ctx_len
+        1,  # prefill_bsz
+        None,  # full_batch_size
+        "JackFram/llama-68m",  # model_name
         id="non-CB llama",
     ),
 ]
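
The fourteen rewritten parameter lines appear to differ from their predecessors only in whitespace. Each pytest.param supplies one positional value per inline comment, which the tests presumably unpack through a decorator along these lines (a sketch; only the test name and parameter order come from this diff):

    import pytest

    @pytest.mark.parametrize(
        "device_group,num_speculative_tokens,prompt_len,ctx_len,prefill_bsz,full_batch_size,model_name",
        configs,
    )
    def test_llama_tlm_logit_dims(
        device_group, num_speculative_tokens, prompt_len, ctx_len, prefill_bsz, full_batch_size, model_name
    ):
        ...

The two cases differ only in full_batch_size: 8 exercises the continuous-batching (CB) path, while None exercises the regular path.
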
@@ -63,7 +63,7 @@ def test_llama_tlm_logit_dims(
         prompt_len=prompt_len,
         ctx_len=ctx_len,
         mxfp6=True,
-        # mxint8=True,
+        mxint8=True,
         full_batch_size=full_batch_size,
     )
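
This hunk re-enables mxint8=True next to mxfp6=True, matching the commit title; in QEfficient these flags presumably select MXFP6 weight compression and MXINT8 KV-cache compression, as in its CLI. A sketch of the call shape, using a hypothetical stand-in for the helper the test actually invokes (the callee and its remaining arguments are not shown in this hunk):

    def export_and_compile(**kwargs):
        # Hypothetical stand-in for the test's export/compile helper.
        print(kwargs)
        return "/path/to/qpcs"

    qpc_path = export_and_compile(
        model_name="JackFram/llama-68m",  # from the configs above
        prompt_len=32,
        ctx_len=128,
        mxfp6=True,                       # MXFP6 weight compression
        mxint8=True,                      # re-enabled by this commit
        full_batch_size=8,                # or None for the non-CB case
    )
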

@@ -126,7 +126,7 @@ def test_llama_dlm_logit_dims(
         prompt_len=prompt_len,
         ctx_len=ctx_len,
         mxfp6=True,
-        # mxint8=True,
+        mxint8=True,
         full_batch_size=full_batch_size,
     )
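
The DLM compile hunk is identical to the TLM one. Per the test names, the meaningful difference lies in the logit dimensions asserted afterwards; a sketch under assumed speculative-decoding shapes (the vocab size and the exact asserts are not part of this diff):

    import torch

    full_batch_size, num_speculative_tokens, vocab_size = 8, 2, 32000  # assumed

    # Assumption: the TLM (target model) emits num_speculative_tokens + 1 logits
    # per decode step so it can score a draft; the DLM (draft model) emits 1.
    tlm_logits = torch.zeros(full_batch_size, num_speculative_tokens + 1, vocab_size)
    dlm_logits = torch.zeros(full_batch_size, 1, vocab_size)

    assert tlm_logits.shape == (full_batch_size, num_speculative_tokens + 1, vocab_size)
    assert dlm_logits.shape == (full_batch_size, 1, vocab_size)
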

