Skip to content

Commit

Permalink
Feat (nn/sdpa): quantization of scaled dot-product attention (#1090)
Browse files Browse the repository at this point in the history
  • Loading branch information
nickfraser authored Dec 6, 2024
1 parent 7a5f77d commit 85c1626
Show file tree
Hide file tree
Showing 8 changed files with 425 additions and 9 deletions.
2 changes: 2 additions & 0 deletions src/brevitas/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
from .quant_rnn import QuantRNN
from .quant_scale_bias import QuantScaleBias
from .quant_scale_bias import ScaleBias
from .quant_sdpa import QuantScaledDotProductAttention
from .quant_sdpa import ScaledDotProductAttention
from .quant_upsample import QuantUpsample
from .quant_upsample import QuantUpsamplingBilinear2d
from .quant_upsample import QuantUpsamplingNearest2d
Expand Down
209 changes: 209 additions & 0 deletions src/brevitas/nn/quant_sdpa.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
"""
Copyright (C) 2024, Advanced Micro Devices, Inc.
Copyright (c) 2016- Facebook, Inc (Adam Paszke)
Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
Copyright (c) 2011-2013 NYU (Clement Farabet)
Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
3. Neither the names of AMD, Facebook, Deepmind Technologies, NYU,
NEC Laboratories America and IDIAP Research Institute nor the names
of its contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
POSSIBILITY OF SUCH DAMAGE.
"""

import math
from typing import Optional, Tuple, Union

import torch
from torch import Tensor
from torch.nn import Module
from torch.nn import Parameter
import torch.nn.functional as F

from brevitas.quant.scaled_int import Int8ActPerTensorFloat
from brevitas.quant.scaled_int import Uint8ActPerTensorFloat

from .quant_activation import QuantIdentity


class ScaledDotProductAttention(Module):

def forward(
self,
query: Tensor,
key: Tensor,
value: Tensor,
attn_mask: Optional[Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
enable_gqa: bool = False):
r"""
Args:
query (Tensor): Query tensor; shape :math:`(N, ..., Hq, L, E)`.
key (Tensor): Key tensor; shape :math:`(N, ..., H, S, E)`.
value (Tensor): Value tensor; shape :math:`(N, ..., H, S, Ev)`.
attn_mask (optional Tensor): Attention mask; shape must be broadcastable to the shape of attention weights,
which is :math:`(N,..., L, S)`. Two types of masks are supported.
A boolean mask where a value of True indicates that the element *should* take part in attention.
A float mask of the same type as query, key, value that is added to the attention score.
dropout_p (float): Dropout probability; if greater than 0.0, dropout is applied
is_causal (bool): If set to true, the attention masking is a lower triangular matrix when the mask is a
square matrix. The attention masking has the form of the upper left causal bias due to the alignment
(see :class:`torch.nn.attention.bias.CausalBias`) when the mask is a non-square matrix.
An error is thrown if both attn_mask and is_causal are set.
scale (optional float, keyword-only): Scaling factor applied prior to softmax. If None, the default value is set
to :math:`\frac{1}{\sqrt{E}}`.
enable_gqa (bool): Ignored to make calling interface compatible with PyTorch >v2.5. Always set to False.
Returns:
output (Tensor): Attention output; shape :math:`(N, ..., Hq, L, Ev)`.
Shape legend:
- :math:`N: \text{Batch size} ... : \text{Any number of other batch dimensions (optional)}`
- :math:`S: \text{Source sequence length}`
- :math:`L: \text{Target sequence length}`
- :math:`E: \text{Embedding dimension of the query and key}`
- :math:`Ev: \text{Embedding dimension of the value}`
- :math:`Hq: \text{Number of heads of query}`
- :math:`H: \text{Number of heads of key and value}`
"""
kwargs = {}
if scale is not None:
kwargs["scale"] = scale
if enable_gqa:
kwargs["enable_gqa"] = enable_gqa
return F.scaled_dot_product_attention(
query=query,
key=key,
value=value,
attn_mask=attn_mask,
dropout_p=dropout_p,
is_causal=is_causal,
**kwargs)


class QuantScaledDotProductAttention(Module):

def __init__(
self,
softmax_input_quant=None,
attn_output_weights_quant=Uint8ActPerTensorFloat,
q_scaled_quant=Int8ActPerTensorFloat,
k_transposed_quant=Int8ActPerTensorFloat,
v_quant=Int8ActPerTensorFloat,
sdpa_output_quant=None,
**kwargs) -> None:
super(QuantScaledDotProductAttention, self).__init__()

def filter_kwargs(prefix):
return {k[len(prefix):]: v for k, v in kwargs.items() if k.startswith(prefix)}

self.q_scaled_quant = QuantIdentity(act_quant=q_scaled_quant, **filter_kwargs('q_scaled_'))
self.k_transposed_quant = QuantIdentity(
act_quant=k_transposed_quant, **filter_kwargs('k_transposed_'))
self.v_quant = QuantIdentity(act_quant=v_quant, **filter_kwargs('v_'))
self.softmax_input_quant = QuantIdentity(
act_quant=softmax_input_quant, **filter_kwargs('softmax_input_'))
self.attn_output_weights_quant = QuantIdentity(
act_quant=attn_output_weights_quant, **filter_kwargs('attn_output_weights_'))
self.sdpa_output_quant = QuantIdentity(
act_quant=sdpa_output_quant, **filter_kwargs('sdpa_output_'))

def forward(
self,
query: Tensor,
key: Tensor,
value: Tensor,
attn_mask: Optional[Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
enable_gqa: bool = False):
r"""
Args:
query (Tensor): Query tensor; shape :math:`(N, ..., Hq, L, E)`.
key (Tensor): Key tensor; shape :math:`(N, ..., H, S, E)`.
value (Tensor): Value tensor; shape :math:`(N, ..., H, S, Ev)`.
attn_mask (optional Tensor): Attention mask; shape must be broadcastable to the shape of attention weights,
which is :math:`(N,..., L, S)`. Two types of masks are supported.
A boolean mask where a value of True indicates that the element *should* take part in attention.
A float mask of the same type as query, key, value that is added to the attention score.
dropout_p (float): Dropout probability; if greater than 0.0, dropout is applied
is_causal (bool): If set to true, the attention masking is a lower triangular matrix when the mask is a
square matrix. The attention masking has the form of the upper left causal bias due to the alignment
(see :class:`torch.nn.attention.bias.CausalBias`) when the mask is a non-square matrix.
An error is thrown if both attn_mask and is_causal are set.
scale (optional float, keyword-only): Scaling factor applied prior to softmax. If None, the default value is set
to :math:`\frac{1}{\sqrt{E}}`.
enable_gqa (bool): Ignored to make calling interface compatible with PyTorch >v2.5. Always set to False.
Returns:
output (Tensor): Attention output; shape :math:`(N, ..., Hq, L, Ev)`.
Shape legend:
- :math:`N: \text{Batch size} ... : \text{Any number of other batch dimensions (optional)}`
- :math:`S: \text{Source sequence length}`
- :math:`L: \text{Target sequence length}`
- :math:`E: \text{Embedding dimension of the query and key}`
- :math:`Ev: \text{Embedding dimension of the value}`
- :math:`Hq: \text{Number of heads of query}`
- :math:`H: \text{Number of heads of key and value}`
"""
L, S = query.size(-2), key.size(-2)
scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
if attn_mask is None:
attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
else:
attn_bias = torch.zeros(size=attn_mask.shape, dtype=query.dtype, device=query.device)
if is_causal:
assert attn_mask is None
temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
attn_bias.to(dtype=query.dtype, device=query.device)

if attn_mask is not None:
if attn_mask.dtype == torch.bool:
attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
else:
attn_bias += attn_mask
q_scaled = self.q_scaled_quant(query * scale_factor)
k_transpose = self.k_transposed_quant(key.transpose(-2, -1))
attn_weight = q_scaled @ k_transpose
attn_weight += attn_bias
attn_weight = self.softmax_input_quant(attn_weight)
attn_weight = torch.softmax(attn_weight, dim=-1)
attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
attn_weight = self.attn_output_weights_quant(attn_weight)
attn_output = attn_weight @ self.v_quant(value)
attn_output = self.sdpa_output_quant(attn_output)
return attn_output
14 changes: 13 additions & 1 deletion src/brevitas_examples/common/generative/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,13 +463,25 @@ def generate_quant_maps(
'dtype': dtype,
'device': device}

quant_sdpa_kwargs = {
'softmax_input_quant': None,
'attn_output_weights_quant': attn_output_weights_quant,
'attn_output_weights_signed': 'float' in input_quant_format,
'q_scaled_quant': q_scaled_quant,
'k_transposed_quant': k_transposed_quant,
'v_quant': v_quant,
'attn_output_quant': None,
'dtype': dtype,
'device': device}

layer_map = {
nn.Linear: (qnn.QuantLinear, quant_linear_kwargs),
nn.Conv2d: (qnn.QuantConv2d, quant_conv_kwargs),
'diffusers.models.lora.LoRACompatibleLinear':
(LoRACompatibleQuantLinear, quant_linear_kwargs),
'diffusers.models.lora.LoRACompatibleConv': (LoRACompatibleQuantConv2d, quant_conv_kwargs),
nn.MultiheadAttention: (qnn.QuantMultiheadAttention, quant_mha_kwargs)}
nn.MultiheadAttention: (qnn.QuantMultiheadAttention, quant_mha_kwargs),
qnn.ScaledDotProductAttention: (qnn.QuantScaledDotProductAttention, quant_sdpa_kwargs)}

if quantize_embedding:
quant_embedding_kwargs = {'weight_quant': weight_quant, 'dtype': dtype, 'device': device}
Expand Down
10 changes: 6 additions & 4 deletions src/brevitas_examples/llm/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

## Run

Set the env variable BREVITAS_JIT=1 to speed up the quantization process. Currently unsupported whenever export is also toggled or with MSE based scales/zero-points.
Set the env variable `BREVITAS_JIT=1` to speed up the quantization process. Currently unsupported whenever export is also toggled or with MSE based scales/zero-points.

```bash
usage: main.py [-h] [--model MODEL] [--seed SEED] [--nsamples NSAMPLES]
Expand Down Expand Up @@ -46,8 +46,9 @@ usage: main.py [-h] [--model MODEL] [--seed SEED] [--nsamples NSAMPLES]
[--act-calibration] [--bias-corr] [--ln-affine-merge]
[--convert-layernorm-to-rmsnorm] [--replace-rmsnorm]
[--no-quantize] [--no-float16]
[--scaling-min-val SCALING_MIN_VAL] [--replace-mha]
[--weight-equalization] [--rotation {fx,layerwise,fused_no_fx}]
[--scaling-min-val SCALING_MIN_VAL] [--quant-sdpa]
[--replace-mha] [--weight-equalization]
[--rotation {fx,layerwise,fused_no_fx}]
[--rotation-mode {had,ort}] [--rotation-orphan-sink]
[--act-equalization {None,layerwise,fx}] [--load-awq LOAD_AWQ]
[--export-target {None,onnx_qcdq,torch_qcdq,sharded_torchmlir_group_weight,sharded_packed_torchmlir_group_weight}]
Expand Down Expand Up @@ -160,6 +161,8 @@ options:
--scaling-min-val SCALING_MIN_VAL
Minimum value to clamp scale to when using bf16 or
fp16 quantization.
--quant-sdpa Quantize `F.scaled_dot_product_attention` (default:
False)
--replace-mha Replace HuggingFace Attention with a quantizable
version
--weight-equalization
Expand Down Expand Up @@ -200,5 +203,4 @@ options:
--learned-round-fast-update
Whether to use fast update with learned round.
Prototype (default: False)

```
10 changes: 10 additions & 0 deletions src/brevitas_examples/llm/llm_quant/prepare_for_quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,14 @@

from packaging import version
import torch
import torch.nn.functional as F
import transformers
from transformers.models.opt.modeling_opt import OPTAttention

from brevitas.graph import ModuleToModuleByClass
from brevitas.graph import TorchFunctionalToModule
from brevitas.nn import QuantScaledDotProductAttention
from brevitas.nn import ScaledDotProductAttention
from brevitas_examples.llm.llm_quant.mha_layers import QuantizableOPTAttention

QUANTIZABLE_MHA_MAP = {
Expand All @@ -35,6 +39,12 @@ def replace_mha_with_quantizable_layers(model, dtype):
return model


def replace_sdpa_with_quantizable_layers(graph_model):
fn_to_module_map = ((F.scaled_dot_product_attention, ScaledDotProductAttention),)
graph_model = TorchFunctionalToModule(fn_to_module_map=fn_to_module_map).apply(graph_model)
return graph_model


@torch.no_grad()
def add_zero_bias_to_linear(model: torch.nn.Module) -> torch.nn.Module:
for name, module in model.named_modules():
Expand Down
17 changes: 16 additions & 1 deletion src/brevitas_examples/llm/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
from brevitas_examples.llm.llm_quant.ln_affine_merge import replace_rmsnorm_with_torch
from brevitas_examples.llm.llm_quant.prepare_for_quantize import add_zero_bias_to_linear
from brevitas_examples.llm.llm_quant.prepare_for_quantize import replace_mha_with_quantizable_layers
from brevitas_examples.llm.llm_quant.prepare_for_quantize import \
replace_sdpa_with_quantizable_layers
from brevitas_examples.llm.llm_quant.run_utils import CastFloat16ToFloat32
from brevitas_examples.llm.llm_quant.run_utils import fix_rewriter
from brevitas_examples.llm.llm_quant.run_utils import get_fx
Expand Down Expand Up @@ -180,7 +182,12 @@ def main(args):
else:
dtype = torch.float16

# Whether to quantize SDPA with FX
quant_sdpa_fx = args.quant_sdpa and not args.replace_mha

kwargs = {"torch_dtype": dtype}
if quant_sdpa_fx:
kwargs["attn_implementation"] = "sdpa"

if args.export_target == 'torch_qcdq':
kwargs['torchscript'] = True
Expand All @@ -199,7 +206,7 @@ def main(args):
with CastFloat16ToFloat32():
apply_awq(model, awq_results)

require_fx = True if args.weight_equalization or args.act_equalization == 'fx' or args.ln_affine_merge or args.convert_layernorm_to_rmsnorm else False
require_fx = True if args.weight_equalization or args.act_equalization == 'fx' or args.ln_affine_merge or args.convert_layernorm_to_rmsnorm or quant_sdpa_fx else False

# Load the data for calibration and evaluation.
calibration_loader = get_dataset_for_model(
Expand Down Expand Up @@ -280,6 +287,10 @@ def main(args):
print("Replace HF MHA with quantizable variants...")
model = replace_mha_with_quantizable_layers(model, dtype)
print("Replacing done.")
elif quant_sdpa_fx:
print("Replace `F.scaled_dot_product_attention` with QuantSDPA...")
model = replace_sdpa_with_quantizable_layers(model)
print("Replacing done.")

if args.weight_equalization:
print("Apply weight equalization...")
Expand Down Expand Up @@ -636,6 +647,10 @@ def parse_args(args):
type=float,
default=1e-4,
help='Minimum value to clamp scale to when using bf16 or fp16 quantization.')
parser.add_argument(
'--quant-sdpa',
action='store_true',
help='Quantize `F.scaled_dot_product_attention` (default: %(default)s)')
parser.add_argument(
'--replace-mha',
action='store_true',
Expand Down
Loading

0 comments on commit 85c1626

Please sign in to comment.