From fecc6739effb865b168aec1cb8eb7a553cf08465 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Mon, 2 Sep 2024 17:48:17 +0100 Subject: [PATCH] test (example/llm): refactor run tests. --- tests/brevitas_examples/test_llm.py | 25 +++++++++++-------------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/tests/brevitas_examples/test_llm.py b/tests/brevitas_examples/test_llm.py index 7e786f4a4..55974af86 100644 --- a/tests/brevitas_examples/test_llm.py +++ b/tests/brevitas_examples/test_llm.py @@ -111,6 +111,15 @@ def default_run_args(request): return args +def run_test_models_run_args(args, model_with_ppl): + args.model = model_with_ppl.name + exp_float_ppl = model_with_ppl.float_ppl + use_fx = requires_fx(args) + if use_fx and not model_with_ppl.supports_fx: + pytest.xfail(f"{model_with_ppl.name} does not support FX") + float_ppl, quant_ppl, model = main(args) + + @pytest_cases.fixture( ids=[ "defaults", @@ -144,13 +153,7 @@ def toggle_run_args(default_run_args, request): @requires_pt_ge('2.2') def test_small_models_toggle_run_args(caplog, toggle_run_args, small_models_with_ppl): caplog.set_level(logging.INFO) - args = toggle_run_args - args.model = small_models_with_ppl.name - exp_float_ppl = small_models_with_ppl.float_ppl - use_fx = requires_fx(args) - if use_fx and not small_models_with_ppl.supports_fx: - pytest.xfail(f"{small_models_with_ppl.name} does not support FX") - float_ppl, quant_ppl, model = main(args) + run_test_models_run_args(toggle_run_args, small_models_with_ppl) @pytest_cases.fixture( @@ -172,13 +175,7 @@ def small_models_with_ppl_pt_ge_2_4(request): def test_small_models_toggle_run_args_pt_ge_2_4( caplog, toggle_run_args, small_models_with_ppl_pt_ge_2_4): caplog.set_level(logging.INFO) - args = toggle_run_args - args.model = small_models_with_ppl_pt_ge_2_4.name - exp_float_ppl = small_models_with_ppl_pt_ge_2_4.float_ppl - use_fx = requires_fx(args) - if use_fx and not small_models_with_ppl_pt_ge_2_4.supports_fx: - pytest.xfail(f"{small_models_with_ppl.name} does not support FX") - float_ppl, quant_ppl, model = main(args) + run_test_models_run_args(toggle_run_args, small_models_with_ppl_pt_ge_2_4) @pytest_cases.fixture(