Fix test structure, hopefully faster
Giuseppe5 committed Sep 17, 2024
1 parent 6b49188 commit 24d84c1
Showing 1 changed file with 40 additions and 32 deletions.
72 changes: 40 additions & 32 deletions tests/brevitas_end_to_end/test_torchvision_models.py
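The gist of the restructuring: the single heavyweight fixture is split into a shared helper, shared_quant_fn, plus two parametrized fixtures, so the compile test only builds models from the short COMPILE_MODEL_LIST instead of skipping its way through the full model/quantizer cross-product. Below is a minimal sketch of that pattern, assuming the fixture/parametrize decorators come from pytest_cases as in the test module; the list contents, build_model, and the str.upper stand-in for a quantization function are placeholders, not Brevitas APIs.

# Sketch only: placeholder names stand in for MODEL_LIST, COMPILE_MODEL_LIST
# and the Brevitas quantize_* functions used in the diff below.
from pytest_cases import fixture, parametrize

FULL_LIST = ['model_a', 'model_b', 'model_c']   # analogous to MODEL_LIST
SMALL_LIST = ['model_a']                        # analogous to COMPILE_MODEL_LIST


def build_model(model_name, quantize_fn):
    # Shared body, analogous to shared_quant_fn: build and quantize once so
    # both fixtures reuse the same code path.
    return quantize_fn(model_name)


@fixture
@parametrize('model_name', FULL_LIST)
def full_model(model_name):
    # Used by the export tests over the whole list.
    return build_model(model_name, str.upper)


@fixture
@parametrize('model_name', SMALL_LIST)
def compile_model(model_name):
    # Only the short list feeds the expensive torch.compile test.
    return build_model(model_name, str.upper)
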
@@ -26,6 +26,9 @@
BATCH = 1
HEIGHT, WIDTH = 224, 224
IN_CH = 3

COMPILE_MODEL_LIST = ['efficientnet_b0', 'resnet18', 'fcn_resnet50']

MODEL_LIST = [
'vit_b_32',
'efficientnet_b0',
@@ -70,11 +73,7 @@ def quantize_float(model):
quant_format='float')


@fixture
@parametrize('model_name', MODEL_LIST)
@parametrize('quantize_fn', [quantize_float, quantize, layerwise_quantize, quantize_flexml])
def torchvision_model(model_name, quantize_fn):

def shared_quant_fn(model_name, quantize_fn):
inp = torch.randn(BATCH, IN_CH, HEIGHT, WIDTH)

if torch_version <= version.parse('1.9.1') and model_name == 'regnet_x_400mf':
@@ -114,44 +113,53 @@ def torchvision_model(model_name, quantize_fn):
return model


@requires_pt_ge('1.8.1')
@parametrize('enable_compile', [True, False])
def test_torchvision_graph_quantization_flexml_qcdq_onnx(
torchvision_model, enable_compile, request):
test_id = request.node.callspec.id
if torchvision_model is None:
@fixture
@parametrize('model_name', MODEL_LIST)
@parametrize('quantize_fn', [quantize_float, quantize, layerwise_quantize, quantize_flexml])
def torchvision_model(model_name, quantize_fn):
return shared_quant_fn(model_name, quantize_fn)


@fixture
@parametrize('model_name', COMPILE_MODEL_LIST)
@parametrize('quantize_fn', [quantize_float, quantize])
def torchvision_model_compile(model_name, quantize_fn):
return shared_quant_fn(model_name, quantize_fn)


@requires_pt_ge('2.2')
def test_torchvision_compile(torchvision_model_compile):
torch._dynamo.config.capture_scalar_outputs = True
if torchvision_model_compile is None:
pytest.skip('Model not instantiated')
if enable_compile:
model_name = test_id.split("-")[1]
quant_func = test_id.split("-")[0]
if torch_version <= version.parse('2.2'):
pytest.skip("Pytorch 2.2 is required to test compile")
elif quant_func not in ('quantize_float', 'quantize'):
pytest.skip("Compile is tested only against base float and int quantization functions")
else:
torch._dynamo.config.capture_scalar_outputs = True
if 'vit' in model_name:
pytest.skip("QuantMHA not supported with compile")

inp = torch.randn(BATCH, IN_CH, HEIGHT, WIDTH)

quantize_fn_name = test_id.split("-")[0]
with torch.no_grad(), quant_inference_mode(torchvision_model):
prehook_non_compiled_out = torchvision_model(inp)
post_hook_non_compiled_out = torchvision_model(inp)
with torch.no_grad(), quant_inference_mode(torchvision_model_compile):
prehook_non_compiled_out = torchvision_model_compile(inp)
post_hook_non_compiled_out = torchvision_model_compile(inp)

compiled_model = torch.compile(torchvision_model_compile, fullgraph=True)
compiled_out = compiled_model(inp)

assert torch.allclose(prehook_non_compiled_out, post_hook_non_compiled_out)
assert torch.allclose(post_hook_non_compiled_out, compiled_out, atol=TORCH_COMPILE_ATOL)


if enable_compile:
compiled_model = torch.compile(torchvision_model, fullgraph=True)
compiled_out = compiled_model(inp)
def test_torchvision_graph_quantization_flexml_qcdq_onnx(torchvision_model, request):
test_id = request.node.callspec.id
if torchvision_model is None:
pytest.skip('Model not instantiated')

assert torch.allclose(post_hook_non_compiled_out, compiled_out, atol=TORCH_COMPILE_ATOL)
inp = torch.randn(BATCH, IN_CH, HEIGHT, WIDTH)

quantize_fn_name = test_id.split("-")[0]
torchvision_model(inp)

if quantize_fn_name != 'quantize_float' and not enable_compile:
if quantize_fn_name != 'quantize_float':
export_onnx_qcdq(torchvision_model, args=inp)


@requires_pt_ge('1.9.1')
def test_torchvision_graph_quantization_flexml_qcdq_torch(torchvision_model, request):
if torchvision_model is None:
pytest.skip('Model not instantiated')
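For reference, the new test_torchvision_compile reduces to a standard eager-versus-compiled comparison: run the quantized model twice in inference mode so any state-mutating pre-hooks settle, compile it with fullgraph=True, and require the outputs to match within a tolerance. Below is a self-contained sketch of that pattern on a plain nn.Module; it assumes PyTorch 2.x, omits the Brevitas quantization step and the quant_inference_mode context used in the diff, and the tolerance is illustrative rather than the repository's TORCH_COMPILE_ATOL.

# Minimal sketch of the compile-vs-eager check; not the Brevitas test itself.
import torch
import torch._dynamo
import torch.nn as nn

ATOL = 1e-4  # illustrative; the real test uses TORCH_COMPILE_ATOL


def check_compile_matches_eager(model: nn.Module, inp: torch.Tensor, atol: float = ATOL) -> None:
    # Mirror the dynamo flag the new test sets before compiling.
    torch._dynamo.config.capture_scalar_outputs = True

    with torch.no_grad():
        first_out = model(inp)   # a first pass may update cached state via hooks
        second_out = model(inp)  # a second pass should then be stable
        compiled_out = torch.compile(model, fullgraph=True)(inp)

    assert torch.allclose(first_out, second_out)
    assert torch.allclose(second_out, compiled_out, atol=atol)


model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.ReLU()).eval()
check_compile_matches_eager(model, torch.randn(1, 3, 224, 224))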
