Alternative approach to support torch.compile #1006

Merged · 25 commits · Sep 23, 2024

Changes from 1 commit
Fix test structure, hopefully faster
Giuseppe5 committed Sep 17, 2024

Verified: this commit was created on GitHub.com and signed with GitHub's verified signature. The key has expired.
commit 24d84c10a203e219db28bd9c95890ca6448a612e
72 changes: 40 additions & 32 deletions tests/brevitas_end_to_end/test_torchvision_models.py
@@ -26,6 +26,9 @@
 BATCH = 1
 HEIGHT, WIDTH = 224, 224
 IN_CH = 3
+
+COMPILE_MODEL_LIST = ['efficientnet_b0', 'resnet18', 'fcn_resnet50']
+
 MODEL_LIST = [
     'vit_b_32',
     'efficientnet_b0',
@@ -70,11 +73,7 @@ def quantize_float(model):
         quant_format='float')
 
 
-@fixture
-@parametrize('model_name', MODEL_LIST)
-@parametrize('quantize_fn', [quantize_float, quantize, layerwise_quantize, quantize_flexml])
-def torchvision_model(model_name, quantize_fn):
-
+def shared_quant_fn(model_name, quantize_fn):
     inp = torch.randn(BATCH, IN_CH, HEIGHT, WIDTH)
 
     if torch_version <= version.parse('1.9.1') and model_name == 'regnet_x_400mf':
@@ -114,44 +113,53 @@ def torchvision_model(model_name, quantize_fn):
     return model
 
 
-@requires_pt_ge('1.8.1')
-@parametrize('enable_compile', [True, False])
-def test_torchvision_graph_quantization_flexml_qcdq_onnx(
-        torchvision_model, enable_compile, request):
-    test_id = request.node.callspec.id
-    if torchvision_model is None:
+@fixture
+@parametrize('model_name', MODEL_LIST)
+@parametrize('quantize_fn', [quantize_float, quantize, layerwise_quantize, quantize_flexml])
+def torchvision_model(model_name, quantize_fn):
+    return shared_quant_fn(model_name, quantize_fn)
+
+
+@fixture
+@parametrize('model_name', COMPILE_MODEL_LIST)
+@parametrize('quantize_fn', [quantize_float, quantize])
+def torchvision_model_compile(model_name, quantize_fn):
+    return shared_quant_fn(model_name, quantize_fn)
+
+
+@requires_pt_ge('2.2')
+def test_torchvision_compile(torchvision_model_compile):
+    torch._dynamo.config.capture_scalar_outputs = True
+    if torchvision_model_compile is None:
         pytest.skip('Model not instantiated')
-    if enable_compile:
-        model_name = test_id.split("-")[1]
-        quant_func = test_id.split("-")[0]
-        if torch_version <= version.parse('2.2'):
-            pytest.skip("Pytorch 2.2 is required to test compile")
-        elif quant_func not in ('quantize_float', 'quantize'):
-            pytest.skip("Compile is tested only against base float and int quantization functions")
-        else:
-            torch._dynamo.config.capture_scalar_outputs = True
-            if 'vit' in model_name:
-                pytest.skip("QuantMHA not supported with compile")
 
     inp = torch.randn(BATCH, IN_CH, HEIGHT, WIDTH)
 
-    quantize_fn_name = test_id.split("-")[0]
-    with torch.no_grad(), quant_inference_mode(torchvision_model):
-        prehook_non_compiled_out = torchvision_model(inp)
-        post_hook_non_compiled_out = torchvision_model(inp)
+    with torch.no_grad(), quant_inference_mode(torchvision_model_compile):
+        prehook_non_compiled_out = torchvision_model_compile(inp)
+        post_hook_non_compiled_out = torchvision_model_compile(inp)
 
-    assert torch.allclose(prehook_non_compiled_out, post_hook_non_compiled_out)
+        compiled_model = torch.compile(torchvision_model_compile, fullgraph=True)
+        compiled_out = compiled_model(inp)
 
-    if enable_compile:
-        compiled_model = torch.compile(torchvision_model, fullgraph=True)
-        compiled_out = compiled_model(inp)
+        assert torch.allclose(prehook_non_compiled_out, post_hook_non_compiled_out)
+        assert torch.allclose(post_hook_non_compiled_out, compiled_out, atol=TORCH_COMPILE_ATOL)
 
-        assert torch.allclose(post_hook_non_compiled_out, compiled_out, atol=TORCH_COMPILE_ATOL)
 
+def test_torchvision_graph_quantization_flexml_qcdq_onnx(torchvision_model, request):
+    test_id = request.node.callspec.id
+    if torchvision_model is None:
+        pytest.skip('Model not instantiated')
+
+    inp = torch.randn(BATCH, IN_CH, HEIGHT, WIDTH)
+
+    quantize_fn_name = test_id.split("-")[0]
+    torchvision_model(inp)
 
-    if quantize_fn_name != 'quantize_float' and not enable_compile:
+    if quantize_fn_name != 'quantize_float':
        export_onnx_qcdq(torchvision_model, args=inp)
 
 
 @requires_pt_ge('1.9.1')
 def test_torchvision_graph_quantization_flexml_qcdq_torch(torchvision_model, request):
     if torchvision_model is None:
         pytest.skip('Model not instantiated')
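
Two patterns in this commit are worth spelling out for readers who do not want to mentally re-indent the diff. First, the structural fix named in the commit message: the model-building logic moves into a plain helper (shared_quant_fn), and two separately parametrized fixtures wrap it, so the compile test runs on a deliberately smaller model/quantizer grid than the export tests. Below is a stripped-down sketch of that shape; it assumes the @fixture/@parametrize decorators come from pytest_cases (the diff does not show the imports), and the helper body is a placeholder rather than Brevitas's real quantization code.

from pytest_cases import fixture, parametrize

# Placeholder lists for illustration; the real MODEL_LIST is truncated
# in the diff above, so only COMPILE_MODEL_LIST matches the commit exactly.
MODEL_LIST = ['vit_b_32', 'efficientnet_b0', 'resnet18', 'fcn_resnet50']
COMPILE_MODEL_LIST = ['efficientnet_b0', 'resnet18', 'fcn_resnet50']


def shared_quant_fn(model_name, quantize_fn):
    # Placeholder: in the real file this builds the torchvision model,
    # applies quantize_fn, and returns the quantized model (or None when
    # the combination is unsupported).
    return (model_name, quantize_fn)


@fixture
@parametrize('model_name', MODEL_LIST)
@parametrize('quantize_fn', ['quantize_float', 'quantize'])
def torchvision_model(model_name, quantize_fn):
    # Full grid: every model crossed with every quantization function.
    return shared_quant_fn(model_name, quantize_fn)


@fixture
@parametrize('model_name', COMPILE_MODEL_LIST)
@parametrize('quantize_fn', ['quantize_float', 'quantize'])
def torchvision_model_compile(model_name, quantize_fn):
    # Narrower grid for torch.compile: fewer models, only the two base
    # quantization functions, replacing the in-test skips the old
    # enable_compile branching performed (e.g. vit models were skipped
    # because QuantMHA is unsupported with compile).
    return shared_quant_fn(model_name, quantize_fn)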
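
Second, the new test_torchvision_compile reduces to an eager-versus-compiled consistency check. Here is a minimal, self-contained sketch of that pattern against an unquantized torchvision model (PyTorch >= 2.2, per the test's decorator); it omits the Brevitas quant_inference_mode context, and the 1e-4 tolerance is an illustrative stand-in for the TORCH_COMPILE_ATOL constant defined in the test file.

import torch
from torchvision import models

BATCH, IN_CH, HEIGHT, WIDTH = 1, 3, 224, 224

# Mirrors the flag set in test_torchvision_compile: without it, reading a
# scalar out of a tensor inside the model can break fullgraph=True capture.
torch._dynamo.config.capture_scalar_outputs = True


def check_compile_matches_eager(model, atol=1e-4):
    model.eval()
    inp = torch.randn(BATCH, IN_CH, HEIGHT, WIDTH)
    with torch.no_grad():
        # Two eager passes, as in the test: a mismatch here would point at
        # hooks mutating state between calls rather than at torch.compile.
        prehook_out = model(inp)
        posthook_out = model(inp)

        compiled_model = torch.compile(model, fullgraph=True)
        compiled_out = compiled_model(inp)

    assert torch.allclose(prehook_out, posthook_out)
    assert torch.allclose(posthook_out, compiled_out, atol=atol)


check_compile_matches_eager(models.resnet18(weights=None))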