Skip to content

Commit

Permalink
rm flag from non-test definition
Browse files Browse the repository at this point in the history
Signed-off-by: eplatero <[email protected]>
  • Loading branch information
eplatero97 committed Dec 5, 2024
1 parent 6202874 commit 0b85209
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions tests/transformers/test_transformer_pytorch_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def compare_original_vs_kv_model_pt_outputs(original_val, kv_val, tolerance=1e-6


def run_kv_cache_transform_and_test(
hf_model, qaic_model_inputs, logits_tolerance=0.8, kv_cache=None, is_tlm=False,
hf_model, qaic_model_inputs, logits_tolerance=0.8, kv_cache=None,
):
hf_model.eval()
# Run original model
Expand All @@ -161,6 +161,7 @@ def run_kv_cache_transform_and_test(
original_model_outputs = hf_model(input_ids=input_ids, output_hidden_states=True)

# Apply transforms
is_tlm = "num_logits_to_keep" in qaic_model_inputs
hf_model = QEFFAutoModelForCausalLM(hf_model, is_tlm=is_tlm).model


Expand Down Expand Up @@ -290,7 +291,6 @@ def test_spd_transform(
qaic_model_inputs=qaic_model_inputs,
logits_tolerance=logits_tolerance,
kv_cache=kv_cache,
is_tlm=True,
)


Expand Down

0 comments on commit 0b85209

Please sign in to comment.