From 32c94b96e0b693089a2d32f20751cd3b7a536178 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Thu, 28 Nov 2024 16:46:58 +0000 Subject: [PATCH] test (nn/sdpa): Fix when PT<2.1 --- tests/brevitas/nn/test_sdpa.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/brevitas/nn/test_sdpa.py b/tests/brevitas/nn/test_sdpa.py index 856f1ef01..efb6e4aa5 100644 --- a/tests/brevitas/nn/test_sdpa.py +++ b/tests/brevitas/nn/test_sdpa.py @@ -37,6 +37,8 @@ def test_sdpa_fwd(self, dropout_p, is_causal, scale, enable_gqa, rand_attn_mask) "enable_gqa": enable_gqa,} if torch_version < version.parse('2.5.0'): del extra_kwargs["enable_gqa"] + if torch_version < version.parse('2.1.0'): + del extra_kwargs["scale"] kv_length = PAST_SEQUENCE_LENGTH + SEQUENCE_LENGTH m = ScaledDotProductAttention()