From f0701bae61cb88e960d68d8a4cb3c7663df80f2c Mon Sep 17 00:00:00 2001
From: Nick Fraser
Date: Thu, 5 Dec 2024 13:19:40 +0000
Subject: [PATCH] fix (nn/sdpa): Rename output quantizer to sdpa_output_quant to avoid name clashes

---
 src/brevitas/nn/quant_sdpa.py  |  8 ++++----
 tests/brevitas/nn/test_sdpa.py | 12 ++++++------
 2 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/src/brevitas/nn/quant_sdpa.py b/src/brevitas/nn/quant_sdpa.py
index 96e85d489..43f99e827 100644
--- a/src/brevitas/nn/quant_sdpa.py
+++ b/src/brevitas/nn/quant_sdpa.py
@@ -121,7 +121,7 @@ def __init__(
             q_scaled_quant=Int8ActPerTensorFloat,
             k_transposed_quant=Int8ActPerTensorFloat,
             v_quant=Int8ActPerTensorFloat,
-            attn_output_quant=None,
+            sdpa_output_quant=None,
             **kwargs) -> None:
         super(QuantScaledDotProductAttention, self).__init__()
 
@@ -136,8 +136,8 @@ def filter_kwargs(prefix):
             act_quant=softmax_input_quant, **filter_kwargs('softmax_input_'))
         self.attn_output_weights_quant = QuantIdentity(
             act_quant=attn_output_weights_quant, **filter_kwargs('attn_output_weights_'))
-        self.attn_output_quant = QuantIdentity(
-            act_quant=attn_output_quant, **filter_kwargs('attn_output_'))
+        self.sdpa_output_quant = QuantIdentity(
+            act_quant=sdpa_output_quant, **filter_kwargs('sdpa_output_'))
 
     def forward(
             self,
@@ -205,5 +205,5 @@ def forward(
         attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
         attn_weight = self.attn_output_weights_quant(attn_weight)
         attn_output = attn_weight @ self.v_quant(value)
-        attn_output = self.attn_output_quant(attn_output)
+        attn_output = self.sdpa_output_quant(attn_output)
         return attn_output
diff --git a/tests/brevitas/nn/test_sdpa.py b/tests/brevitas/nn/test_sdpa.py
index 281429f0f..b38415ea8 100644
--- a/tests/brevitas/nn/test_sdpa.py
+++ b/tests/brevitas/nn/test_sdpa.py
@@ -33,19 +33,19 @@ def test_sdpa_init(self):
             'q_scaled_bit_width': 4,
             'k_transposed_bit_width': 5,
             'v_bit_width': 6,
-            'attn_output_bit_width': 7,}
+            'sdpa_output_bit_width': 7,}
         qm = QuantScaledDotProductAttention(
             softmax_input_quant=Int8ActPerTensorFloat,
             attn_output_weights_quant=Uint8ActPerTensorFloat,
             q_scaled_quant=Int8ActPerTensorFloat,
             k_transposed_quant=Int8ActPerTensorFloat,
             v_quant=Int8ActPerTensorFloat,
-            attn_output_quant=Int8ActPerTensorFloat,
+            sdpa_output_quant=Int8ActPerTensorFloat,
             **extra_kwargs,
         )
 
         # Check that the `kwargs` have been applied correctly
-        prefixes = ["softmax_input", "attn_output", "q_scaled", "v", "attn_output"]
+        prefixes = ["softmax_input", "attn_output_weights", "q_scaled", "v", "sdpa_output"]
         for k in extra_kwargs.keys():
             checked = False
             if "softmax_input_" in k:
@@ -64,8 +64,8 @@ def test_sdpa_init(self):
             elif "v_" in k:
                 assert int(qm.v_quant.act_quant.bit_width().item()) == extra_kwargs[k]
                 checked = True
-            elif "attn_output_" in k:
-                assert int(qm.attn_output_quant.act_quant.bit_width().item()) == extra_kwargs[k]
+            elif "sdpa_output_" in k:
+                assert int(qm.sdpa_output_quant.act_quant.bit_width().item()) == extra_kwargs[k]
                 checked = True
             assert checked, f"Unmatched kwarg: {k}"
 
@@ -131,7 +131,7 @@ def test_sdpa_quant_disabled_fwd(self, dropout_p, is_causal, scale, enable_gqa,
             q_scaled_quant=None,
             k_transposed_quant=None,
             v_quant=None,
-            attn_output_quant=None,
+            sdpa_output_quant=None,
         )
         q = torch.randn(BATCH_SIZE, HEAD_DIM, SEQUENCE_LENGTH, EMBED_DIM)
         k = torch.randn(BATCH_SIZE, HEAD_DIM, kv_length, EMBED_DIM)
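
Note on the rename: the clash the subject refers to appears to come from the prefix-based kwarg
routing in __init__ above. The body of filter_kwargs is not shown in this patch, but its call
sites suggest it selects **kwargs entries by name prefix, and 'attn_output_' is a strict prefix
of 'attn_output_weights_', so kwargs meant for attn_output_weights_quant would also match the
output quantizer's prefix. The sketch below is a standalone approximation of that routing, not
the library code: filter_kwargs is rewritten to take kwargs as an explicit argument (in the tree
it is a closure over **kwargs), purely so the example runs on its own.

    # Standalone sketch of the assumed prefix-based kwarg routing.
    def filter_kwargs(prefix, kwargs):
        # Keep only kwargs whose name starts with `prefix`, with the prefix stripped.
        return {k[len(prefix):]: v for k, v in kwargs.items() if k.startswith(prefix)}

    # Old naming: 'attn_output_' also matches 'attn_output_weights_*' kwargs,
    # so the weights bit width leaks into the output quantizer's kwargs.
    old_kwargs = {'attn_output_weights_bit_width': 3, 'attn_output_bit_width': 7}
    print(filter_kwargs('attn_output_', old_kwargs))
    # {'weights_bit_width': 3, 'bit_width': 7}  <- clash

    # New naming: 'sdpa_output_' no longer overlaps with 'attn_output_weights_'.
    new_kwargs = {'attn_output_weights_bit_width': 3, 'sdpa_output_bit_width': 7}
    print(filter_kwargs('sdpa_output_', new_kwargs))
    # {'bit_width': 7}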