Skip to content

Commit

Permalink
test (nn/sdpa): Fix when PT<2.1
Browse files Browse the repository at this point in the history
  • Loading branch information
nickfraser committed Nov 28, 2024
1 parent 6b1e51a commit 32c94b9
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions tests/brevitas/nn/test_sdpa.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ def test_sdpa_fwd(self, dropout_p, is_causal, scale, enable_gqa, rand_attn_mask)
"enable_gqa": enable_gqa,}
if torch_version < version.parse('2.5.0'):
del extra_kwargs["enable_gqa"]
if torch_version < version.parse('2.1.0'):
del extra_kwargs["scale"]

kv_length = PAST_SEQUENCE_LENGTH + SEQUENCE_LENGTH
m = ScaledDotProductAttention()
Expand Down

0 comments on commit 32c94b9

Please sign in to comment.