test (example/llm): Added tests to ensure all args to `main` are also in `parse_args`

nickfraser committed Sep 10, 2024
1 parent 0e762f7 commit 6e9283a
Showing 1 changed file with 21 additions and 7 deletions.
28 changes: 21 additions & 7 deletions tests/brevitas_examples/test_llm.py
@@ -37,6 +37,20 @@ def allexact(x, y):
    return np.allclose(x, y, rtol=0.0, atol=0.0, equal_nan=False)


+# Check that every arg in `args` is a valid argument of `parse_args`
+def validate_args(args):
+    a = vars(args)
+    da = vars(parse_args([]))
+    for k in a.keys():
+        assert k in da.keys(), f"Key {k} does not seem to be a valid argument for `main`"
+
+
+def validate_args_and_run_main(args):
+    validate_args(args)
+    float_ppl, quant_ppl, model = main(args)
+    return float_ppl, quant_ppl, model
+
+
def assert_layer_types(model, exp_layer_types):
    for key, string in exp_layer_types.items():
        matched = False
@@ -118,7 +132,7 @@ def run_test_models_run_args(args, model_with_ppl):
    use_fx = requires_fx(args)
    if use_fx and not model_with_ppl.supports_fx:
        pytest.xfail(f"{model_with_ppl.name} does not support FX")
-    float_ppl, quant_ppl, model = main(args)
+    float_ppl, quant_ppl, model = validate_args_and_run_main(args)


@pytest_cases.fixture(
@@ -212,7 +226,7 @@ def acc_args_and_acc(default_run_args, request):
def test_small_models_acc(caplog, acc_args_and_acc):
    caplog.set_level(logging.INFO)
    args, exp_float_ppl, exp_quant_ppl = acc_args_and_acc
-    float_ppl, quant_ppl, model = main(args)
+    float_ppl, quant_ppl, model = validate_args_and_run_main(args)
    float_ppl = float_ppl.detach().cpu().numpy()
    quant_ppl = quant_ppl.detach().cpu().numpy()
    assert allveryclose(exp_float_ppl, float_ppl), f"Expected float PPL {exp_float_ppl}, measured PPL {float_ppl}"
@@ -246,7 +260,7 @@ def acc_args_and_acc_pt_ge_2_4(default_run_args, request):
def test_small_models_acc_pt_ge_2_4(caplog, acc_args_and_acc_pt_ge_2_4):
    caplog.set_level(logging.INFO)
    args, exp_float_ppl, exp_quant_ppl = acc_args_and_acc_pt_ge_2_4
-    float_ppl, quant_ppl, model = main(args)
+    float_ppl, quant_ppl, model = validate_args_and_run_main(args)
    float_ppl = float_ppl.detach().cpu().numpy()
    quant_ppl = quant_ppl.detach().cpu().numpy()
    assert allveryclose(exp_float_ppl, float_ppl), f"Expected float PPL {exp_float_ppl}, measured PPL {float_ppl}"
@@ -365,7 +379,7 @@ def layer_args(default_run_args, request):
def test_small_models_quant_layer(caplog, layer_args):
    caplog.set_level(logging.INFO)
    args, exp_layer_types = layer_args
-    float_ppl, quant_ppl, model = main(args)
+    float_ppl, quant_ppl, model = validate_args_and_run_main(args)
    assert_layer_types(model, exp_layer_types)


Expand Down Expand Up @@ -395,7 +409,7 @@ def layer_args_pt_ge_2_4(default_run_args, request):
def test_small_models_quant_layer_pt_ge_2_4(caplog, layer_args_pt_ge_2_4):
caplog.set_level(logging.INFO)
args, exp_layer_types = layer_args_pt_ge_2_4
float_ppl, quant_ppl, model = main(args)
float_ppl, quant_ppl, model = validate_args_and_run_main(args)
assert_layer_types(model, exp_layer_types)


Expand Down Expand Up @@ -427,7 +441,7 @@ def onnx_export_args(default_run_args, request):
def test_small_models_onnx_export(caplog, onnx_export_args):
caplog.set_level(logging.INFO)
args = onnx_export_args
float_ppl, quant_ppl, model = main(args)
float_ppl, quant_ppl, model = validate_args_and_run_main(args)
onnx_model = onnx.load(os.path.join(args.export_prefix, "model.onnx"))
shutil.rmtree(args.export_prefix)

Expand Down Expand Up @@ -462,7 +476,7 @@ def torch_export_args(default_run_args, request):
def test_small_models_torch_export(caplog, torch_export_args):
caplog.set_level(logging.INFO)
args = torch_export_args
float_ppl, quant_ppl, model = main(args)
float_ppl, quant_ppl, model = validate_args_and_run_main(args)
filepath = args.export_prefix + ".pt"
torchscript_model = torch.jit.load(filepath)
os.remove(filepath)
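
A minimal usage sketch of the new guard, assuming `parse_args` and `validate_args` as defined in the diff above; the `stale_option` attribute is hypothetical, added only to show the failure mode:

import pytest

args = parse_args([])                 # namespace built from the real CLI defaults
validate_args(args)                   # passes: every key is a known parser argument

args.stale_option = True              # attribute that `parse_args` never defines
with pytest.raises(AssertionError):
    validate_args(args)               # fails before `main(args)` would ever run

Because every test now routes through `validate_args_and_run_main`, a test argument that drifts out of sync with the CLI parser fails immediately with a clear assertion message instead of being silently ignored by `main`.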
