From 24d84c10a203e219db28bd9c95890ca6448a612e Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 17 Sep 2024 02:37:15 +0100 Subject: [PATCH] Fix test structure, hopefully faster --- .../test_torchvision_models.py | 72 ++++++++++--------- 1 file changed, 40 insertions(+), 32 deletions(-) diff --git a/tests/brevitas_end_to_end/test_torchvision_models.py b/tests/brevitas_end_to_end/test_torchvision_models.py index f00920f3e..09f0b9253 100644 --- a/tests/brevitas_end_to_end/test_torchvision_models.py +++ b/tests/brevitas_end_to_end/test_torchvision_models.py @@ -26,6 +26,9 @@ BATCH = 1 HEIGHT, WIDTH = 224, 224 IN_CH = 3 + +COMPILE_MODEL_LIST = ['efficientnet_b0', 'resnet18', 'fcn_resnet50'] + MODEL_LIST = [ 'vit_b_32', 'efficientnet_b0', @@ -70,11 +73,7 @@ def quantize_float(model): quant_format='float') -@fixture -@parametrize('model_name', MODEL_LIST) -@parametrize('quantize_fn', [quantize_float, quantize, layerwise_quantize, quantize_flexml]) -def torchvision_model(model_name, quantize_fn): - +def shared_quant_fn(model_name, quantize_fn): inp = torch.randn(BATCH, IN_CH, HEIGHT, WIDTH) if torch_version <= version.parse('1.9.1') and model_name == 'regnet_x_400mf': @@ -114,44 +113,53 @@ def torchvision_model(model_name, quantize_fn): return model -@requires_pt_ge('1.8.1') -@parametrize('enable_compile', [True, False]) -def test_torchvision_graph_quantization_flexml_qcdq_onnx( - torchvision_model, enable_compile, request): - test_id = request.node.callspec.id - if torchvision_model is None: +@fixture +@parametrize('model_name', MODEL_LIST) +@parametrize('quantize_fn', [quantize_float, quantize, layerwise_quantize, quantize_flexml]) +def torchvision_model(model_name, quantize_fn): + return shared_quant_fn(model_name, quantize_fn) + + +@fixture +@parametrize('model_name', COMPILE_MODEL_LIST) +@parametrize('quantize_fn', [quantize_float, quantize]) +def torchvision_model_compile(model_name, quantize_fn): + return shared_quant_fn(model_name, quantize_fn) + + +@requires_pt_ge('2.2') +def test_torchvision_compile(torchvision_model_compile): + torch._dynamo.config.capture_scalar_outputs = True + if torchvision_model_compile is None: pytest.skip('Model not instantiated') - if enable_compile: - model_name = test_id.split("-")[1] - quant_func = test_id.split("-")[0] - if torch_version <= version.parse('2.2'): - pytest.skip("Pytorch 2.2 is required to test compile") - elif quant_func not in ('quantize_float', 'quantize'): - pytest.skip("Compile is tested only against base float and int quantization functions") - else: - torch._dynamo.config.capture_scalar_outputs = True - if 'vit' in model_name: - pytest.skip("QuantMHA not supported with compile") inp = torch.randn(BATCH, IN_CH, HEIGHT, WIDTH) - quantize_fn_name = test_id.split("-")[0] - with torch.no_grad(), quant_inference_mode(torchvision_model): - prehook_non_compiled_out = torchvision_model(inp) - post_hook_non_compiled_out = torchvision_model(inp) + with torch.no_grad(), quant_inference_mode(torchvision_model_compile): + prehook_non_compiled_out = torchvision_model_compile(inp) + post_hook_non_compiled_out = torchvision_model_compile(inp) + + compiled_model = torch.compile(torchvision_model_compile, fullgraph=True) + compiled_out = compiled_model(inp) + assert torch.allclose(prehook_non_compiled_out, post_hook_non_compiled_out) + assert torch.allclose(post_hook_non_compiled_out, compiled_out, atol=TORCH_COMPILE_ATOL) + - if enable_compile: - compiled_model = torch.compile(torchvision_model, fullgraph=True) - compiled_out = compiled_model(inp) +def test_torchvision_graph_quantization_flexml_qcdq_onnx(torchvision_model, request): + test_id = request.node.callspec.id + if torchvision_model is None: + pytest.skip('Model not instantiated') - assert torch.allclose(post_hook_non_compiled_out, compiled_out, atol=TORCH_COMPILE_ATOL) + inp = torch.randn(BATCH, IN_CH, HEIGHT, WIDTH) + + quantize_fn_name = test_id.split("-")[0] + torchvision_model(inp) - if quantize_fn_name != 'quantize_float' and not enable_compile: + if quantize_fn_name != 'quantize_float': export_onnx_qcdq(torchvision_model, args=inp) -@requires_pt_ge('1.9.1') def test_torchvision_graph_quantization_flexml_qcdq_torch(torchvision_model, request): if torchvision_model is None: pytest.skip('Model not instantiated')