diff --git a/tests/brevitas/nn/test_sdpa.py b/tests/brevitas/nn/test_sdpa.py index 856f1ef01..efb6e4aa5 100644 --- a/tests/brevitas/nn/test_sdpa.py +++ b/tests/brevitas/nn/test_sdpa.py @@ -37,6 +37,8 @@ def test_sdpa_fwd(self, dropout_p, is_causal, scale, enable_gqa, rand_attn_mask) "enable_gqa": enable_gqa,} if torch_version < version.parse('2.5.0'): del extra_kwargs["enable_gqa"] + if torch_version < version.parse('2.1.0'): + del extra_kwargs["scale"] kv_length = PAST_SEQUENCE_LENGTH + SEQUENCE_LENGTH m = ScaledDotProductAttention()