Skip to content

Commit

Permalink
test (nn): Added quant_disabled QSDPA test
Browse files Browse the repository at this point in the history
  • Loading branch information
nickfraser committed Nov 28, 2024
1 parent e3d3c4a commit 3798923
Showing 1 changed file with 43 additions and 1 deletion.
44 changes: 43 additions & 1 deletion tests/brevitas/nn/test_sdpa.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class TestScaledDotProductAttention:
@pytest.mark.parametrize("scale", [None, 0.3])
@pytest.mark.parametrize("enable_gqa", [False, True])
@pytest.mark.parametrize("rand_attn_mask", [False, True])
# Sanity check, since `ScaledDotProductAttention` just called `F.scaled_dot_product_attention` in its forward function
# Sanity check, since `ScaledDotProductAttention` just calls `F.scaled_dot_product_attention` in its forward function
def test_sdpa_fwd(self, dropout_p, is_causal, scale, enable_gqa, rand_attn_mask):
extra_kwargs = {
"dropout_p": dropout_p,
Expand Down Expand Up @@ -56,3 +56,45 @@ def test_sdpa_fwd(self, dropout_p, is_causal, scale, enable_gqa, rand_attn_mask)
out = m(q, k, v, attn_mask, **extra_kwargs)
assert torch.isclose(out, ref_out, atol=ATOL).all()
assert torch.isclose(out, ref_out, atol=ATOL).all()

@requires_pt_ge('2.0')
@pytest.mark.parametrize("dropout_p", [0.0, 0.5])
@pytest.mark.parametrize("is_causal", [True, False])
@pytest.mark.parametrize("scale", [None, 0.3])
@pytest.mark.parametrize("enable_gqa", [False, True])
@pytest.mark.parametrize("rand_attn_mask", [False, True])
def test_sdpa_quant_disabled_fwd(self, dropout_p, is_causal, scale, enable_gqa, rand_attn_mask):
    """Compare `QuantScaledDotProductAttention` with all quantizers disabled to the float module.

    A `QuantScaledDotProductAttention` is built with every quantizer argument
    set to `None` and run on the same random inputs as a plain
    `ScaledDotProductAttention`. The two outputs must agree element-wise
    within `ATOL`.

    Args:
        dropout_p: Dropout probability forwarded to both modules.
        is_causal: Causal-masking flag forwarded to both modules.
        scale: Optional softmax scale forwarded to both modules.
        enable_gqa: Grouped-query-attention flag; only forwarded when the
            installed torch version is at least 2.5.0.
        rand_attn_mask: When true (and `is_causal` is false), a random
            boolean attention mask is passed instead of `None`.
    """
    extra_kwargs = {
        "dropout_p": dropout_p,
        "is_causal": is_causal,
        "scale": scale,
    }
    # `enable_gqa` is only passed on torch >= 2.5.0; older versions are
    # called without the keyword at all.
    if torch_version >= version.parse('2.5.0'):
        extra_kwargs["enable_gqa"] = enable_gqa

    kv_length = PAST_SEQUENCE_LENGTH + SEQUENCE_LENGTH
    m = ScaledDotProductAttention()
    # Every quantizer is explicitly disabled, so `qm` is expected to behave
    # like the unquantized reference module `m`.
    qm = QuantScaledDotProductAttention(
        softmax_input_quant=None,
        attn_output_weights_quant=None,
        q_scaled_quant=None,
        k_transposed_quant=None,
        v_quant=None,
        attn_output_quant=None,
    )
    q = torch.randn(BATCH_SIZE, HEAD_DIM, SEQUENCE_LENGTH, EMBED_DIM)
    k = torch.randn(BATCH_SIZE, HEAD_DIM, kv_length, EMBED_DIM)
    v = torch.randn(BATCH_SIZE, HEAD_DIM, kv_length, EMBED_DIM)
    # An explicit mask is never combined with `is_causal=True` in this test.
    if rand_attn_mask and not is_causal:
        attn_mask = torch.randint(
            low=0, high=2, size=(BATCH_SIZE, 1, SEQUENCE_LENGTH, kv_length), dtype=torch.bool)
    else:
        attn_mask = None
    # Reseed before each forward pass so both modules start from the same
    # RNG state when dropout is active.
    if dropout_p > 0.0:
        torch.manual_seed(DROPOUT_SEED)
    ref_out = m(q, k, v, attn_mask, **extra_kwargs)
    if dropout_p > 0.0:
        torch.manual_seed(DROPOUT_SEED)
    out = qm(q, k, v, attn_mask, **extra_kwargs)
    assert torch.isclose(out, ref_out, atol=ATOL).all()

0 comments on commit 3798923

Please sign in to comment.