From 5e31f2ffca2772daff25b530c383217190a1ec19 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Sun, 26 Nov 2023 14:52:12 +0000 Subject: [PATCH] Fix (tests): skip some MHA tests with torch 2.0.1 --- tests/brevitas/graph/equalization_fixtures.py | 12 ++++++++++++ tests/brevitas/graph/test_equalization.py | 7 +++++-- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/tests/brevitas/graph/equalization_fixtures.py b/tests/brevitas/graph/equalization_fixtures.py index 128f3953a..8e13c49cd 100644 --- a/tests/brevitas/graph/equalization_fixtures.py +++ b/tests/brevitas/graph/equalization_fixtures.py @@ -109,6 +109,10 @@ def linearmha_model(bias, add_bias_kv, batch_first): if torch_version < version.parse('1.9.1'): pytest.skip(f"batch_first not supported in MHA with torch version {torch_version}") + # Skip due to following issue https://github.com/pytorch/pytorch/issues/97128 + if torch_version == version.parse('2.0.1') and not bias and batch_first and not add_bias_kv: + pytest.skip(f"Skip due to a regression in pytorch 2.0.1") + class LinearMhaModel(nn.Module): def __init__(self) -> None: @@ -135,6 +139,10 @@ def layernormmha_model(bias, add_bias_kv, batch_first): if torch_version < version.parse('1.9.1'): pytest.skip(f"batch_first not supported in MHA with torch version {torch_version}") + # Skip due to following issue https://github.com/pytorch/pytorch/issues/97128 + if torch_version == version.parse('2.0.1') and not bias and batch_first and not add_bias_kv: + pytest.skip(f"Skip due to a regression in pytorch 2.0.1") + class LayerNormMhaModel(nn.Module): def __init__(self) -> None: @@ -164,6 +172,10 @@ def mhalinear_model(bias, add_bias_kv, batch_first): if torch_version < version.parse('1.9.1'): pytest.skip(f"batch_first not supported in MHA with torch version {torch_version}") + # Skip due to following issue https://github.com/pytorch/pytorch/issues/97128 + if torch_version == version.parse('2.0.1') and not bias and batch_first and not add_bias_kv: + pytest.skip(f"Skip due to a regression in pytorch 2.0.1") + class MhaLinearModel(nn.Module): def __init__(self) -> None: diff --git a/tests/brevitas/graph/test_equalization.py b/tests/brevitas/graph/test_equalization.py index 7ae9022b7..caca0fd29 100644 --- a/tests/brevitas/graph/test_equalization.py +++ b/tests/brevitas/graph/test_equalization.py @@ -124,7 +124,9 @@ def test_models(toy_model, merge_bias, request): inp = torch.randn(in_shape) model.eval() - expected_out = model(inp) + with torch.no_grad(): + expected_out = model(inp) + model = symbolic_trace(model) regions = _extract_regions(model) scale_factor_regions = equalize_test( @@ -135,7 +137,8 @@ def test_models(toy_model, merge_bias, request): scale_computation_type='maxabs') shape_scale_regions = [scale.shape for scale in scale_factor_regions] - out = model(inp) + with torch.no_grad(): + out = model(inp) assert len(regions) > 0 assert torch.allclose(expected_out, out, atol=ATOL) # Check that at least one region performs "true" equalization