From 85c16262ace2a0ad19cc76f813bfaa81acab60ae Mon Sep 17 00:00:00 2001 From: nickfraser Date: Fri, 6 Dec 2024 09:44:45 +0000 Subject: [PATCH] Feat (nn/sdpa): quantization of scaled dot-product attention (#1090) --- src/brevitas/nn/__init__.py | 2 + src/brevitas/nn/quant_sdpa.py | 209 ++++++++++++++++++ .../common/generative/quantize.py | 14 +- src/brevitas_examples/llm/README.md | 10 +- .../llm/llm_quant/prepare_for_quantize.py | 10 + src/brevitas_examples/llm/main.py | 17 +- tests/brevitas/nn/test_sdpa.py | 151 +++++++++++++ tests/brevitas_examples/test_llm.py | 21 +- 8 files changed, 425 insertions(+), 9 deletions(-) create mode 100644 src/brevitas/nn/quant_sdpa.py create mode 100644 tests/brevitas/nn/test_sdpa.py diff --git a/src/brevitas/nn/__init__.py b/src/brevitas/nn/__init__.py index 4176e30cf..e58a5fa0c 100644 --- a/src/brevitas/nn/__init__.py +++ b/src/brevitas/nn/__init__.py @@ -28,6 +28,8 @@ from .quant_rnn import QuantRNN from .quant_scale_bias import QuantScaleBias from .quant_scale_bias import ScaleBias +from .quant_sdpa import QuantScaledDotProductAttention +from .quant_sdpa import ScaledDotProductAttention from .quant_upsample import QuantUpsample from .quant_upsample import QuantUpsamplingBilinear2d from .quant_upsample import QuantUpsamplingNearest2d diff --git a/src/brevitas/nn/quant_sdpa.py b/src/brevitas/nn/quant_sdpa.py new file mode 100644 index 000000000..43f99e827 --- /dev/null +++ b/src/brevitas/nn/quant_sdpa.py @@ -0,0 +1,209 @@ +""" +Copyright (C) 2024, Advanced Micro Devices, Inc. +Copyright (c) 2016- Facebook, Inc (Adam Paszke) +Copyright (c) 2014- Facebook, Inc (Soumith Chintala) +Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) +Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) +Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) +Copyright (c) 2011-2013 NYU (Clement Farabet) +Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) +Copyright (c) 2006 Idiap Research Institute (Samy Bengio) +Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) + +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + +3. Neither the names of AMD, Facebook, Deepmind Technologies, NYU, + NEC Laboratories America and IDIAP Research Institute nor the names + of its contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +POSSIBILITY OF SUCH DAMAGE. +""" + +import math +from typing import Optional, Tuple, Union + +import torch +from torch import Tensor +from torch.nn import Module +from torch.nn import Parameter +import torch.nn.functional as F + +from brevitas.quant.scaled_int import Int8ActPerTensorFloat +from brevitas.quant.scaled_int import Uint8ActPerTensorFloat + +from .quant_activation import QuantIdentity + + +class ScaledDotProductAttention(Module): + + def forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + attn_mask: Optional[Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False): + r""" + Args: + query (Tensor): Query tensor; shape :math:`(N, ..., Hq, L, E)`. + key (Tensor): Key tensor; shape :math:`(N, ..., H, S, E)`. + value (Tensor): Value tensor; shape :math:`(N, ..., H, S, Ev)`. + attn_mask (optional Tensor): Attention mask; shape must be broadcastable to the shape of attention weights, + which is :math:`(N,..., L, S)`. Two types of masks are supported. + A boolean mask where a value of True indicates that the element *should* take part in attention. + A float mask of the same type as query, key, value that is added to the attention score. + dropout_p (float): Dropout probability; if greater than 0.0, dropout is applied + is_causal (bool): If set to true, the attention masking is a lower triangular matrix when the mask is a + square matrix. The attention masking has the form of the upper left causal bias due to the alignment + (see :class:`torch.nn.attention.bias.CausalBias`) when the mask is a non-square matrix. + An error is thrown if both attn_mask and is_causal are set. + scale (optional float, keyword-only): Scaling factor applied prior to softmax. If None, the default value is set + to :math:`\frac{1}{\sqrt{E}}`. + enable_gqa (bool): Ignored to make calling interface compatible with PyTorch >v2.5. Always set to False. + + Returns: + output (Tensor): Attention output; shape :math:`(N, ..., Hq, L, Ev)`. + + Shape legend: + - :math:`N: \text{Batch size} ... 
: \text{Any number of other batch dimensions (optional)}` + - :math:`S: \text{Source sequence length}` + - :math:`L: \text{Target sequence length}` + - :math:`E: \text{Embedding dimension of the query and key}` + - :math:`Ev: \text{Embedding dimension of the value}` + - :math:`Hq: \text{Number of heads of query}` + - :math:`H: \text{Number of heads of key and value}` + """ + kwargs = {} + if scale is not None: + kwargs["scale"] = scale + if enable_gqa: + kwargs["enable_gqa"] = enable_gqa + return F.scaled_dot_product_attention( + query=query, + key=key, + value=value, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=is_causal, + **kwargs) + + +class QuantScaledDotProductAttention(Module): + + def __init__( + self, + softmax_input_quant=None, + attn_output_weights_quant=Uint8ActPerTensorFloat, + q_scaled_quant=Int8ActPerTensorFloat, + k_transposed_quant=Int8ActPerTensorFloat, + v_quant=Int8ActPerTensorFloat, + sdpa_output_quant=None, + **kwargs) -> None: + super(QuantScaledDotProductAttention, self).__init__() + + def filter_kwargs(prefix): + return {k[len(prefix):]: v for k, v in kwargs.items() if k.startswith(prefix)} + + self.q_scaled_quant = QuantIdentity(act_quant=q_scaled_quant, **filter_kwargs('q_scaled_')) + self.k_transposed_quant = QuantIdentity( + act_quant=k_transposed_quant, **filter_kwargs('k_transposed_')) + self.v_quant = QuantIdentity(act_quant=v_quant, **filter_kwargs('v_')) + self.softmax_input_quant = QuantIdentity( + act_quant=softmax_input_quant, **filter_kwargs('softmax_input_')) + self.attn_output_weights_quant = QuantIdentity( + act_quant=attn_output_weights_quant, **filter_kwargs('attn_output_weights_')) + self.sdpa_output_quant = QuantIdentity( + act_quant=sdpa_output_quant, **filter_kwargs('sdpa_output_')) + + def forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + attn_mask: Optional[Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False): + r""" + Args: + query (Tensor): Query tensor; shape :math:`(N, ..., Hq, L, E)`. + key (Tensor): Key tensor; shape :math:`(N, ..., H, S, E)`. + value (Tensor): Value tensor; shape :math:`(N, ..., H, S, Ev)`. + attn_mask (optional Tensor): Attention mask; shape must be broadcastable to the shape of attention weights, + which is :math:`(N,..., L, S)`. Two types of masks are supported. + A boolean mask where a value of True indicates that the element *should* take part in attention. + A float mask of the same type as query, key, value that is added to the attention score. + dropout_p (float): Dropout probability; if greater than 0.0, dropout is applied + is_causal (bool): If set to true, the attention masking is a lower triangular matrix when the mask is a + square matrix. The attention masking has the form of the upper left causal bias due to the alignment + (see :class:`torch.nn.attention.bias.CausalBias`) when the mask is a non-square matrix. + An error is thrown if both attn_mask and is_causal are set. + scale (optional float, keyword-only): Scaling factor applied prior to softmax. If None, the default value is set + to :math:`\frac{1}{\sqrt{E}}`. + enable_gqa (bool): Ignored to make calling interface compatible with PyTorch >v2.5. Always set to False. + + Returns: + output (Tensor): Attention output; shape :math:`(N, ..., Hq, L, Ev)`. + + Shape legend: + - :math:`N: \text{Batch size} ... 
: \text{Any number of other batch dimensions (optional)}` + - :math:`S: \text{Source sequence length}` + - :math:`L: \text{Target sequence length}` + - :math:`E: \text{Embedding dimension of the query and key}` + - :math:`Ev: \text{Embedding dimension of the value}` + - :math:`Hq: \text{Number of heads of query}` + - :math:`H: \text{Number of heads of key and value}` + """ + L, S = query.size(-2), key.size(-2) + scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale + if attn_mask is None: + attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device) + else: + attn_bias = torch.zeros(size=attn_mask.shape, dtype=query.dtype, device=query.device) + if is_causal: + assert attn_mask is None + temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_bias.to(dtype=query.dtype, device=query.device) + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf")) + else: + attn_bias += attn_mask + q_scaled = self.q_scaled_quant(query * scale_factor) + k_transpose = self.k_transposed_quant(key.transpose(-2, -1)) + attn_weight = q_scaled @ k_transpose + attn_weight += attn_bias + attn_weight = self.softmax_input_quant(attn_weight) + attn_weight = torch.softmax(attn_weight, dim=-1) + attn_weight = torch.dropout(attn_weight, dropout_p, train=True) + attn_weight = self.attn_output_weights_quant(attn_weight) + attn_output = attn_weight @ self.v_quant(value) + attn_output = self.sdpa_output_quant(attn_output) + return attn_output diff --git a/src/brevitas_examples/common/generative/quantize.py b/src/brevitas_examples/common/generative/quantize.py index 7e9b9c897..6c156ec1a 100644 --- a/src/brevitas_examples/common/generative/quantize.py +++ b/src/brevitas_examples/common/generative/quantize.py @@ -463,13 +463,25 @@ def generate_quant_maps( 'dtype': dtype, 'device': device} + quant_sdpa_kwargs = { + 'softmax_input_quant': None, + 'attn_output_weights_quant': attn_output_weights_quant, + 'attn_output_weights_signed': 'float' in input_quant_format, + 'q_scaled_quant': q_scaled_quant, + 'k_transposed_quant': k_transposed_quant, + 'v_quant': v_quant, + 'attn_output_quant': None, + 'dtype': dtype, + 'device': device} + layer_map = { nn.Linear: (qnn.QuantLinear, quant_linear_kwargs), nn.Conv2d: (qnn.QuantConv2d, quant_conv_kwargs), 'diffusers.models.lora.LoRACompatibleLinear': (LoRACompatibleQuantLinear, quant_linear_kwargs), 'diffusers.models.lora.LoRACompatibleConv': (LoRACompatibleQuantConv2d, quant_conv_kwargs), - nn.MultiheadAttention: (qnn.QuantMultiheadAttention, quant_mha_kwargs)} + nn.MultiheadAttention: (qnn.QuantMultiheadAttention, quant_mha_kwargs), + qnn.ScaledDotProductAttention: (qnn.QuantScaledDotProductAttention, quant_sdpa_kwargs)} if quantize_embedding: quant_embedding_kwargs = {'weight_quant': weight_quant, 'dtype': dtype, 'device': device} diff --git a/src/brevitas_examples/llm/README.md b/src/brevitas_examples/llm/README.md index 64c82c80a..8471c84ee 100644 --- a/src/brevitas_examples/llm/README.md +++ b/src/brevitas_examples/llm/README.md @@ -10,7 +10,7 @@ ## Run -Set the env variable BREVITAS_JIT=1 to speed up the quantization process. Currently unsupported whenever export is also toggled or with MSE based scales/zero-points. +Set the env variable `BREVITAS_JIT=1` to speed up the quantization process. Currently unsupported whenever export is also toggled or with MSE based scales/zero-points. 
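For context on what SDPA quantization covers, the new module can also be used standalone as a drop-in for `F.scaled_dot_product_attention`. A minimal sketch (tensor shapes are made up for illustration, and the default quantizers of `QuantScaledDotProductAttention` are assumed):

```python
import torch

from brevitas.nn import QuantScaledDotProductAttention

# Toy shapes (batch, heads, seq_len, embed_dim), chosen only for illustration.
q = torch.randn(2, 4, 8, 16)
k = torch.randn(2, 4, 8, 16)
v = torch.randn(2, 4, 8, 16)

# Defaults: int8 quantization of the scaled query, transposed key and value;
# uint8 quantization of the post-softmax attention weights; softmax-input and
# SDPA-output quantization stay disabled unless explicitly configured.
quant_sdpa = QuantScaledDotProductAttention()

# Same calling convention as F.scaled_dot_product_attention.
out = quant_sdpa(q, k, v, is_causal=True)
```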
```bash usage: main.py [-h] [--model MODEL] [--seed SEED] [--nsamples NSAMPLES] @@ -46,8 +46,9 @@ usage: main.py [-h] [--model MODEL] [--seed SEED] [--nsamples NSAMPLES] [--act-calibration] [--bias-corr] [--ln-affine-merge] [--convert-layernorm-to-rmsnorm] [--replace-rmsnorm] [--no-quantize] [--no-float16] - [--scaling-min-val SCALING_MIN_VAL] [--replace-mha] - [--weight-equalization] [--rotation {fx,layerwise,fused_no_fx}] + [--scaling-min-val SCALING_MIN_VAL] [--quant-sdpa] + [--replace-mha] [--weight-equalization] + [--rotation {fx,layerwise,fused_no_fx}] [--rotation-mode {had,ort}] [--rotation-orphan-sink] [--act-equalization {None,layerwise,fx}] [--load-awq LOAD_AWQ] [--export-target {None,onnx_qcdq,torch_qcdq,sharded_torchmlir_group_weight,sharded_packed_torchmlir_group_weight}] @@ -160,6 +161,8 @@ options: --scaling-min-val SCALING_MIN_VAL Minimum value to clamp scale to when using bf16 or fp16 quantization. + --quant-sdpa Quantize `F.scaled_dot_product_attention` (default: + False) --replace-mha Replace HuggingFace Attention with a quantizable version --weight-equalization @@ -200,5 +203,4 @@ options: --learned-round-fast-update Whether to use fast update with learned round. Prototype (default: False) - ``` diff --git a/src/brevitas_examples/llm/llm_quant/prepare_for_quantize.py b/src/brevitas_examples/llm/llm_quant/prepare_for_quantize.py index ee2f0b3e6..10fe8325a 100644 --- a/src/brevitas_examples/llm/llm_quant/prepare_for_quantize.py +++ b/src/brevitas_examples/llm/llm_quant/prepare_for_quantize.py @@ -5,10 +5,14 @@ from packaging import version import torch +import torch.nn.functional as F import transformers from transformers.models.opt.modeling_opt import OPTAttention from brevitas.graph import ModuleToModuleByClass +from brevitas.graph import TorchFunctionalToModule +from brevitas.nn import QuantScaledDotProductAttention +from brevitas.nn import ScaledDotProductAttention from brevitas_examples.llm.llm_quant.mha_layers import QuantizableOPTAttention QUANTIZABLE_MHA_MAP = { @@ -35,6 +39,12 @@ def replace_mha_with_quantizable_layers(model, dtype): return model +def replace_sdpa_with_quantizable_layers(graph_model): + fn_to_module_map = ((F.scaled_dot_product_attention, ScaledDotProductAttention),) + graph_model = TorchFunctionalToModule(fn_to_module_map=fn_to_module_map).apply(graph_model) + return graph_model + + @torch.no_grad() def add_zero_bias_to_linear(model: torch.nn.Module) -> torch.nn.Module: for name, module in model.named_modules(): diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index c8c76d4a1..3a678bdf8 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -40,6 +40,8 @@ from brevitas_examples.llm.llm_quant.ln_affine_merge import replace_rmsnorm_with_torch from brevitas_examples.llm.llm_quant.prepare_for_quantize import add_zero_bias_to_linear from brevitas_examples.llm.llm_quant.prepare_for_quantize import replace_mha_with_quantizable_layers +from brevitas_examples.llm.llm_quant.prepare_for_quantize import \ + replace_sdpa_with_quantizable_layers from brevitas_examples.llm.llm_quant.run_utils import CastFloat16ToFloat32 from brevitas_examples.llm.llm_quant.run_utils import fix_rewriter from brevitas_examples.llm.llm_quant.run_utils import get_fx @@ -180,7 +182,12 @@ def main(args): else: dtype = torch.float16 + # Whether to quantize SDPA with FX + quant_sdpa_fx = args.quant_sdpa and not args.replace_mha + kwargs = {"torch_dtype": dtype} + if quant_sdpa_fx: + kwargs["attn_implementation"] = 
"sdpa" if args.export_target == 'torch_qcdq': kwargs['torchscript'] = True @@ -199,7 +206,7 @@ def main(args): with CastFloat16ToFloat32(): apply_awq(model, awq_results) - require_fx = True if args.weight_equalization or args.act_equalization == 'fx' or args.ln_affine_merge or args.convert_layernorm_to_rmsnorm else False + require_fx = True if args.weight_equalization or args.act_equalization == 'fx' or args.ln_affine_merge or args.convert_layernorm_to_rmsnorm or quant_sdpa_fx else False # Load the data for calibration and evaluation. calibration_loader = get_dataset_for_model( @@ -280,6 +287,10 @@ def main(args): print("Replace HF MHA with quantizable variants...") model = replace_mha_with_quantizable_layers(model, dtype) print("Replacing done.") + elif quant_sdpa_fx: + print("Replace `F.scaled_dot_product_attention` with QuantSDPA...") + model = replace_sdpa_with_quantizable_layers(model) + print("Replacing done.") if args.weight_equalization: print("Apply weight equalization...") @@ -636,6 +647,10 @@ def parse_args(args): type=float, default=1e-4, help='Minimum value to clamp scale to when using bf16 or fp16 quantization.') + parser.add_argument( + '--quant-sdpa', + action='store_true', + help='Quantize `F.scaled_dot_product_attention` (default: %(default)s)') parser.add_argument( '--replace-mha', action='store_true', diff --git a/tests/brevitas/nn/test_sdpa.py b/tests/brevitas/nn/test_sdpa.py new file mode 100644 index 000000000..b38415ea8 --- /dev/null +++ b/tests/brevitas/nn/test_sdpa.py @@ -0,0 +1,151 @@ +# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +from packaging import version +import pytest +import torch +import torch.nn.functional as F + +from brevitas import torch_version +from brevitas.nn import QuantScaledDotProductAttention +from brevitas.nn import ScaledDotProductAttention +from brevitas.quant import Int8ActPerTensorFloat +from brevitas.quant import Uint8ActPerTensorFloat +from tests.marker import requires_pt_ge + +ATOL = 1e-6 +EMBED_DIM = 9 +HEAD_DIM = 3 +BATCH_SIZE = 2 +SEQUENCE_LENGTH = 4 +PAST_SEQUENCE_LENGTH = 5 +DROPOUT_SEED = 42 + + +class TestScaledDotProductAttention: + + @requires_pt_ge('2.0') + # Check what kwargs are properly filtered and override defaults + def test_sdpa_init(self): + extra_kwargs = { + 'softmax_input_bit_width': 2, + 'attn_output_weights_bit_width': 3, + 'q_scaled_bit_width': 4, + 'k_transposed_bit_width': 5, + 'v_bit_width': 6, + 'sdpa_output_bit_width': 7,} + qm = QuantScaledDotProductAttention( + softmax_input_quant=Int8ActPerTensorFloat, + attn_output_weights_quant=Uint8ActPerTensorFloat, + q_scaled_quant=Int8ActPerTensorFloat, + k_transposed_quant=Int8ActPerTensorFloat, + v_quant=Int8ActPerTensorFloat, + sdpa_output_quant=Int8ActPerTensorFloat, + **extra_kwargs, + ) + + # Check that the `kwargs` have been applied correctly + prefixes = ["softmax_input", "attn_output_weights", "q_scaled", "v", "sdpa_output"] + for k in extra_kwargs.keys(): + checked = False + if "softmax_input_" in k: + assert int(qm.softmax_input_quant.act_quant.bit_width().item()) == extra_kwargs[k] + checked = True + elif "attn_output_weights_" in k: + assert int( + qm.attn_output_weights_quant.act_quant.bit_width().item()) == extra_kwargs[k] + checked = True + elif "q_scaled_" in k: + assert int(qm.q_scaled_quant.act_quant.bit_width().item()) == extra_kwargs[k] + checked = True + elif "k_transposed_" in k: + assert int(qm.k_transposed_quant.act_quant.bit_width().item()) == extra_kwargs[k] + checked 
= True + elif "v_" in k: + assert int(qm.v_quant.act_quant.bit_width().item()) == extra_kwargs[k] + checked = True + elif "sdpa_output_" in k: + assert int(qm.sdpa_output_quant.act_quant.bit_width().item()) == extra_kwargs[k] + checked = True + assert checked, f"Unmatched kwarg: {k}" + + @requires_pt_ge('2.0') + @pytest.mark.parametrize("dropout_p", [0.0, 0.5]) + @pytest.mark.parametrize("is_causal", [True, False]) + @pytest.mark.parametrize("scale", [None, 0.3]) + @pytest.mark.parametrize("enable_gqa", [False, True]) + @pytest.mark.parametrize("rand_attn_mask", [False, True]) + # Sanity check, since `ScaledDotProductAttention` just calls `F.scaled_dot_product_attention` in its forward function + def test_sdpa_fwd(self, dropout_p, is_causal, scale, enable_gqa, rand_attn_mask): + extra_kwargs = { + "dropout_p": dropout_p, + "is_causal": is_causal, + "scale": scale, + "enable_gqa": enable_gqa,} + if torch_version < version.parse('2.5.0'): + del extra_kwargs["enable_gqa"] + if torch_version < version.parse('2.1.0'): + del extra_kwargs["scale"] + + kv_length = PAST_SEQUENCE_LENGTH + SEQUENCE_LENGTH + m = ScaledDotProductAttention() + q = torch.randn(BATCH_SIZE, HEAD_DIM, SEQUENCE_LENGTH, EMBED_DIM) + k = torch.randn(BATCH_SIZE, HEAD_DIM, kv_length, EMBED_DIM) + v = torch.randn(BATCH_SIZE, HEAD_DIM, kv_length, EMBED_DIM) + if rand_attn_mask and not is_causal: + attn_mask = torch.randint( + low=0, high=2, size=(BATCH_SIZE, 1, SEQUENCE_LENGTH, kv_length), dtype=torch.bool) + else: + attn_mask = None + if dropout_p > 0.0: + torch.manual_seed(DROPOUT_SEED) + ref_out = F.scaled_dot_product_attention(q, k, v, attn_mask, **extra_kwargs) + if dropout_p > 0.0: + torch.manual_seed(DROPOUT_SEED) + out = m(q, k, v, attn_mask, **extra_kwargs) + assert torch.isclose(out, ref_out, atol=ATOL).all() + assert torch.isclose(out, ref_out, atol=ATOL).all() + + @requires_pt_ge('2.0') + @pytest.mark.parametrize("dropout_p", [0.0, 0.5]) + @pytest.mark.parametrize("is_causal", [True, False]) + @pytest.mark.parametrize("scale", [None, 0.3]) + @pytest.mark.parametrize("enable_gqa", [False, True]) + @pytest.mark.parametrize("rand_attn_mask", [False, True]) + def test_sdpa_quant_disabled_fwd(self, dropout_p, is_causal, scale, enable_gqa, rand_attn_mask): + extra_kwargs = { + "dropout_p": dropout_p, + "is_causal": is_causal, + "scale": scale, + "enable_gqa": enable_gqa,} + if torch_version < version.parse('2.5.0'): + del extra_kwargs["enable_gqa"] + if torch_version < version.parse('2.1.0'): + del extra_kwargs["scale"] + + kv_length = PAST_SEQUENCE_LENGTH + SEQUENCE_LENGTH + m = ScaledDotProductAttention() + qm = QuantScaledDotProductAttention( + softmax_input_quant=None, + attn_output_weights_quant=None, + q_scaled_quant=None, + k_transposed_quant=None, + v_quant=None, + sdpa_output_quant=None, + ) + q = torch.randn(BATCH_SIZE, HEAD_DIM, SEQUENCE_LENGTH, EMBED_DIM) + k = torch.randn(BATCH_SIZE, HEAD_DIM, kv_length, EMBED_DIM) + v = torch.randn(BATCH_SIZE, HEAD_DIM, kv_length, EMBED_DIM) + if rand_attn_mask and not is_causal: + attn_mask = torch.randint( + low=0, high=2, size=(BATCH_SIZE, 1, SEQUENCE_LENGTH, kv_length), dtype=torch.bool) + else: + attn_mask = None + if dropout_p > 0.0: + torch.manual_seed(DROPOUT_SEED) + ref_out = m(q, k, v, attn_mask, **extra_kwargs) + if dropout_p > 0.0: + torch.manual_seed(DROPOUT_SEED) + out = qm(q, k, v, attn_mask, **extra_kwargs) + assert torch.isclose(out, ref_out, atol=ATOL).all() + assert torch.isclose(out, ref_out, atol=ATOL).all() diff --git a/tests/brevitas_examples/test_llm.py 
b/tests/brevitas_examples/test_llm.py index 60dd33ac2..61cfae010 100644 --- a/tests/brevitas_examples/test_llm.py +++ b/tests/brevitas_examples/test_llm.py @@ -245,7 +245,8 @@ def test_small_models_acc(caplog, acc_args_and_acc): @pytest_cases.fixture( ids=[ - "opt-replace-mha",], + "opt-replace-mha", + "opt-quant-sdpa",], params=[ { "model": "hf-internal-testing/tiny-random-OPTForCausalLM", # Requires PT>=2.4 to run @@ -253,6 +254,13 @@ def test_small_models_acc(caplog, acc_args_and_acc): "ln_affine_merge": True, "replace_mha": True, "float_ppl": 50016.0, + "quant_ppl": 50016.0}, + { + "model": "hf-internal-testing/tiny-random-OPTForCausalLM", # Requires PT>=2.4 to run + "weight_equalization": True, + "ln_affine_merge": True, + "quant_sdpa": True, + "float_ppl": 50016.0, "quant_ppl": 50016.0},]) def acc_args_and_acc_pt_ge_2_4(default_run_args, request): args = default_run_args @@ -430,7 +438,8 @@ def test_small_models_quant_layer(caplog, layer_args): @pytest_cases.fixture( ids=[ - "opt-replace-mha",], + "opt-replace-mha", + "opt-quant-sdpa",], params=[ { "model": "hf-internal-testing/tiny-random-OPTForCausalLM", # Requires PT>=2.4 to run @@ -439,7 +448,13 @@ def test_small_models_quant_layer(caplog, layer_args): "model.decoder.layers.0.self_attn": "", "model.decoder.layers.0.self_attn.mha": - "",}},]) + "",}}, + { + "model": "hf-internal-testing/tiny-random-OPTForCausalLM", # Requires PT>=2.4 to run + "quant_sdpa": True, + "exp_layer_types": { + "scaled_dot_product_attention": + "",}},]) def layer_args_pt_ge_2_4(default_run_args, request): args = default_run_args layer_dict = request.param
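The FX path behind `--quant-sdpa` can also be exercised on a toy module. Below is a minimal sketch of `replace_sdpa_with_quantizable_layers`; the `ToyAttention` class and tensor shapes are invented for illustration, and the real flow traces the full HF model before applying this rewrite:

```python
import torch
import torch.nn.functional as F
from torch.fx import symbolic_trace

from brevitas_examples.llm.llm_quant.prepare_for_quantize import \
    replace_sdpa_with_quantizable_layers


class ToyAttention(torch.nn.Module):
    # Stand-in for an attention block that calls the functional SDPA.

    def forward(self, q, k, v):
        return F.scaled_dot_product_attention(q, k, v, is_causal=True)


# Trace to an FX graph so the functional call appears as a call_function node.
graph_model = symbolic_trace(ToyAttention())

# Rewrite F.scaled_dot_product_attention into a ScaledDotProductAttention
# module; the generative layer map in quantize.py can then swap that module
# for QuantScaledDotProductAttention during quantization.
graph_model = replace_sdpa_with_quantizable_layers(graph_model)

# The rewritten graph keeps the original calling convention and semantics.
q = k = v = torch.randn(2, 3, 4, 8)
out = graph_model(q, k, v)
```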