From 6b1e51a1f6f748988ac791f279d89ccf6c84753d Mon Sep 17 00:00:00 2001
From: Nick Fraser
Date: Thu, 21 Nov 2024 14:30:11 +0000
Subject: [PATCH] test (fix): sdpa import

---
 tests/brevitas/nn/test_sdpa.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/brevitas/nn/test_sdpa.py b/tests/brevitas/nn/test_sdpa.py
index 9d735a997..856f1ef01 100644
--- a/tests/brevitas/nn/test_sdpa.py
+++ b/tests/brevitas/nn/test_sdpa.py
@@ -4,7 +4,7 @@
 from packaging import version
 import pytest
 import torch
-from torch.nn.functional import scaled_dot_product_attention
+import torch.nn.functional as F
 
 from brevitas import torch_version
 from brevitas.nn import QuantScaledDotProductAttention
@@ -50,7 +50,7 @@ def test_sdpa_fwd(self, dropout_p, is_causal, scale, enable_gqa, rand_attn_mask)
             attn_mask = None
         if dropout_p > 0.0:
             torch.manual_seed(DROPOUT_SEED)
-        ref_out = scaled_dot_product_attention(q, k, v, attn_mask, **extra_kwargs)
+        ref_out = F.scaled_dot_product_attention(q, k, v, attn_mask, **extra_kwargs)
         if dropout_p > 0.0:
             torch.manual_seed(DROPOUT_SEED)
         out = m(q, k, v, attn_mask, **extra_kwargs)
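
A likely motivation for the change: `import torch.nn.functional as F` binds the module rather than the symbol, so the attribute lookup for `scaled_dot_product_attention` (added in PyTorch 2.0) is deferred to call time and the test module still imports on older PyTorch builds, where the version-gated tests are skipped anyway. A minimal sketch of that pattern, assuming a hypothetical skip-guarded test function and reusing the imports from the diff:

    from packaging import version
    import pytest
    import torch
    import torch.nn.functional as F

    from brevitas import torch_version


    @pytest.mark.skipif(
        torch_version < version.parse('2.0.0'),
        reason='scaled_dot_product_attention requires PyTorch >= 2.0')
    def test_reference_sdpa_available():
        # Shapes: (batch, heads, sequence, head_dim); values are arbitrary.
        q = k = v = torch.randn(1, 2, 4, 8)
        # The attribute lookup on F happens here, inside the skip-guarded body,
        # instead of at module import time as with the old "from ... import".
        ref_out = F.scaled_dot_product_attention(q, k, v)
        assert ref_out.shape == q.shape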