Skip to content

Commit

Permalink
test (ex/llm): reorganise to prevent export issues
Browse files Browse the repository at this point in the history
  • Loading branch information
nickfraser committed Aug 23, 2024
1 parent 513cac0 commit a9d6361
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions tests/brevitas_examples/test_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
import torch

from brevitas import config
from brevitas_examples.llm.main import main
from brevitas_examples.llm.main import parse_args
from tests.marker import jit_disabled_for_export


Expand Down Expand Up @@ -94,6 +92,7 @@ def small_models_with_ppl(request):

@pytest.fixture()
def default_run_args(request):
from brevitas_examples.llm.main import parse_args
args = UpdatableNamespace(**vars(parse_args([])))
args.nsamples = 2
args.seqlen = 2
Expand Down Expand Up @@ -132,6 +131,7 @@ def toggle_run_args(default_run_args, request):

@pytest.mark.llm
def test_small_models_toggle_run_args(caplog, toggle_run_args, small_models_with_ppl):
from brevitas_examples.llm.main import main
caplog.set_level(logging.INFO)
args = toggle_run_args
args.model = small_models_with_ppl.name
Expand Down Expand Up @@ -176,6 +176,7 @@ def acc_args_and_acc(default_run_args, request):

@pytest.mark.llm
def test_small_models_acc(caplog, acc_args_and_acc):
from brevitas_examples.llm.main import main
caplog.set_level(logging.INFO)
args, exp_float_ppl, exp_quant_ppl = acc_args_and_acc
float_ppl, quant_ppl, model = main(args)
Expand Down Expand Up @@ -298,6 +299,7 @@ def layer_args(default_run_args, request):

@pytest.mark.llm
def test_small_models_quant_layer(caplog, layer_args):
from brevitas_examples.llm.main import main
caplog.set_level(logging.INFO)
args, exp_layer_types = layer_args
float_ppl, quant_ppl, model = main(args)
Expand Down Expand Up @@ -326,6 +328,7 @@ def onnx_export_args(default_run_args, request):
@pytest.mark.llm
@jit_disabled_for_export()
def test_small_models_onnx_export(caplog, onnx_export_args):
from brevitas_examples.llm.main import main
import onnx
caplog.set_level(logging.INFO)
args = onnx_export_args
Expand Down Expand Up @@ -358,6 +361,7 @@ def torch_export_args(default_run_args, request):
@pytest.mark.llm
@jit_disabled_for_export()
def test_small_models_torch_export(caplog, torch_export_args):
from brevitas_examples.llm.main import main
caplog.set_level(logging.INFO)
args = torch_export_args
float_ppl, quant_ppl, model = main(args)
Expand Down

0 comments on commit a9d6361

Please sign in to comment.