Skip to content

Commit

Permalink
Fix (tests): skip some MHA tests with torch 2.0.1
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Nov 28, 2023
1 parent f0a8c4f commit 5e31f2f
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 2 deletions.
12 changes: 12 additions & 0 deletions tests/brevitas/graph/equalization_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,10 @@ def linearmha_model(bias, add_bias_kv, batch_first):
if torch_version < version.parse('1.9.1'):
pytest.skip(f"batch_first not supported in MHA with torch version {torch_version}")

# Skip due to following issue https://github.com/pytorch/pytorch/issues/97128
if torch_version == version.parse('2.0.1') and not bias and batch_first and not add_bias_kv:
pytest.skip(f"Skip due to a regression in pytorch 2.0.1")

class LinearMhaModel(nn.Module):

def __init__(self) -> None:
Expand All @@ -135,6 +139,10 @@ def layernormmha_model(bias, add_bias_kv, batch_first):
if torch_version < version.parse('1.9.1'):
pytest.skip(f"batch_first not supported in MHA with torch version {torch_version}")

# Skip due to following issue https://github.com/pytorch/pytorch/issues/97128
if torch_version == version.parse('2.0.1') and not bias and batch_first and not add_bias_kv:
pytest.skip(f"Skip due to a regression in pytorch 2.0.1")

class LayerNormMhaModel(nn.Module):

def __init__(self) -> None:
Expand Down Expand Up @@ -164,6 +172,10 @@ def mhalinear_model(bias, add_bias_kv, batch_first):
if torch_version < version.parse('1.9.1'):
pytest.skip(f"batch_first not supported in MHA with torch version {torch_version}")

# Skip due to following issue https://github.com/pytorch/pytorch/issues/97128
if torch_version == version.parse('2.0.1') and not bias and batch_first and not add_bias_kv:
pytest.skip(f"Skip due to a regression in pytorch 2.0.1")

class MhaLinearModel(nn.Module):

def __init__(self) -> None:
Expand Down
7 changes: 5 additions & 2 deletions tests/brevitas/graph/test_equalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,9 @@ def test_models(toy_model, merge_bias, request):
inp = torch.randn(in_shape)

model.eval()
expected_out = model(inp)
with torch.no_grad():
expected_out = model(inp)

model = symbolic_trace(model)
regions = _extract_regions(model)
scale_factor_regions = equalize_test(
Expand All @@ -135,7 +137,8 @@ def test_models(toy_model, merge_bias, request):
scale_computation_type='maxabs')
shape_scale_regions = [scale.shape for scale in scale_factor_regions]

out = model(inp)
with torch.no_grad():
out = model(inp)
assert len(regions) > 0
assert torch.allclose(expected_out, out, atol=ATOL)
# Check that at least one region performs "true" equalization
Expand Down

0 comments on commit 5e31f2f

Please sign in to comment.