From 0def533c7a8429c642e63677b318c8ce86348e6c Mon Sep 17 00:00:00 2001
From: Nick Fraser
Date: Fri, 8 Nov 2024 11:13:10 +0000
Subject: [PATCH] Fix (nn/sdpa): formatting

---
 src/brevitas/graph/standardize.py |  2 +-
 src/brevitas/nn/quant_sdpa.py     | 60 +++++++++++++++++++++++--------
 2 files changed, 47 insertions(+), 15 deletions(-)

diff --git a/src/brevitas/graph/standardize.py b/src/brevitas/graph/standardize.py
index 12692eddf..4c9233e08 100644
--- a/src/brevitas/graph/standardize.py
+++ b/src/brevitas/graph/standardize.py
@@ -10,7 +10,7 @@
 from brevitas.fx import GraphModule
 from brevitas.fx import immutable_dict
 from brevitas.fx import Node
-from brevitas.nn.quant_sdpa import ScaledDotProductAttention
+from brevitas.nn.quant_sdpa import ScaledDotProductAttention
 
 from .base import FnToModule
 from .base import GraphTransform
diff --git a/src/brevitas/nn/quant_sdpa.py b/src/brevitas/nn/quant_sdpa.py
index 368eaa565..9927ca41a 100644
--- a/src/brevitas/nn/quant_sdpa.py
+++ b/src/brevitas/nn/quant_sdpa.py
@@ -53,7 +53,17 @@
 
 
 class ScaledDotProductAttention(Module):
-    def forward(self, query: Tensor, key: Tensor, value: Tensor, attn_mask: Optional[Tensor] = None, dropout_p: float = 0.0, is_causal: bool = False, scale: Optional[float] = None, enable_gqa: bool = False):
+
+    def forward(
+            self,
+            query: Tensor,
+            key: Tensor,
+            value: Tensor,
+            attn_mask: Optional[Tensor] = None,
+            dropout_p: float = 0.0,
+            is_causal: bool = False,
+            scale: Optional[float] = None,
+            enable_gqa: bool = False):
         r"""
         Args:
             query (Tensor): Query tensor; shape :math:`(N, ..., Hq, L, E)`.
@@ -71,10 +81,10 @@ def forward(self, query: Tensor, key: Tensor, value: Tensor, attn_mask: Optional
             scale (optional float, keyword-only): Scaling factor applied prior to softmax. If
                 None, the default value is set to :math:`\frac{1}{\sqrt{E}}`.
             enable_gqa (bool): Ignored to make calling interface compatible with PyTorch >v2.5. Always set to False.
-        
+
         Returns:
             output (Tensor): Attention output; shape :math:`(N, ..., Hq, L, Ev)`.
-        
+
         Shape legend:
             - :math:`N: \text{Batch size} ... : \text{Any number of other batch dimensions (optional)}`
             - :math:`S: \text{Source sequence length}`
@@ -84,22 +94,35 @@ def forward(self, query: Tensor, key: Tensor, value: Tensor, attn_mask: Optional
             - :math:`Hq: \text{Number of heads of query}`
            - :math:`H: \text{Number of heads of key and value}`
         """
-        return F.scaled_dot_product_attention(query=query, key=key, value=value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale)
+        return F.scaled_dot_product_attention(
+            query=query,
+            key=key,
+            value=value,
+            attn_mask=attn_mask,
+            dropout_p=dropout_p,
+            is_causal=is_causal,
+            scale=scale)
 
 
 class QuantScaledDotProductAttention(Module):
-    def __init__(self, query_quant=Int8ActPerTensorFloat, key_quant=Int8ActPerTensorFloat, value_quant=Int8ActPerTensorFloat, softmax_input_quant=Int8ActPerTensorFloat, softmax_output_quant=Uint8ActPerTensorFloat, attn_output_quant=None, **kwargs) -> None:
+
+    def __init__(
+            self,
+            query_quant=Int8ActPerTensorFloat,
+            key_quant=Int8ActPerTensorFloat,
+            value_quant=Int8ActPerTensorFloat,
+            softmax_input_quant=Int8ActPerTensorFloat,
+            softmax_output_quant=Uint8ActPerTensorFloat,
+            attn_output_quant=None,
+            **kwargs) -> None:
         super(QuantScaledDotProductAttention, self).__init__()
 
         def filter_kwargs(prefix):
             return {k[len(prefix):]: v for k, v in kwargs.items() if k.startswith(prefix)}
 
-        self.query_quant = QuantIdentity(
-            act_quant=query_quant, **filter_kwargs('query_'))
-        self.key_quant = QuantIdentity(
-            act_quant=key_quant, **filter_kwargs('key_'))
-        self.value_quant = QuantIdentity(
-            act_quant=value_quant, **filter_kwargs('value_'))
+        self.query_quant = QuantIdentity(act_quant=query_quant, **filter_kwargs('query_'))
+        self.key_quant = QuantIdentity(act_quant=key_quant, **filter_kwargs('key_'))
+        self.value_quant = QuantIdentity(act_quant=value_quant, **filter_kwargs('value_'))
         self.softmax_input_quant = QuantIdentity(
             act_quant=softmax_input_quant, **filter_kwargs('softmax_input_'))
         self.softmax_output_quant = QuantIdentity(
@@ -107,7 +130,16 @@ def filter_kwargs(prefix):
         self.attn_output_quant = QuantIdentity(
             act_quant=attn_output_quant, **filter_kwargs('attn_output_'))
 
-    def forward(self, query: Tensor, key: Tensor, value: Tensor, attn_mask: Optional[Tensor] = None, dropout_p: float = 0.0, is_causal: bool = False, scale: Optional[float] = None, enable_gqa: bool = False):
+    def forward(
+            self,
+            query: Tensor,
+            key: Tensor,
+            value: Tensor,
+            attn_mask: Optional[Tensor] = None,
+            dropout_p: float = 0.0,
+            is_causal: bool = False,
+            scale: Optional[float] = None,
+            enable_gqa: bool = False):
         r"""
         Args:
             query (Tensor): Query tensor; shape :math:`(N, ..., Hq, L, E)`.
@@ -125,10 +157,10 @@ def forward(self, query: Tensor, key: Tensor, value: Tensor, attn_mask: Optional
             scale (optional float, keyword-only): Scaling factor applied prior to softmax. If
                 None, the default value is set to :math:`\frac{1}{\sqrt{E}}`.
             enable_gqa (bool): Ignored to make calling interface compatible with PyTorch >v2.5. Always set to False.
-        
+
         Returns:
             output (Tensor): Attention output; shape :math:`(N, ..., Hq, L, Ev)`.
-        
+
         Shape legend:
             - :math:`N: \text{Batch size} ... : \text{Any number of other batch dimensions (optional)}`
             - :math:`S: \text{Source sequence length}`
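
For reference, outside the patch itself: a minimal usage sketch of the two modules this patch reformats, showing that the call signatures above remain drop-in compatible with F.scaled_dot_product_attention. It assumes the patched brevitas.nn.quant_sdpa with its default quantizers; the tensor shapes follow the docstring legend, and the softmax_output_bit_width=8 override is illustrative only, relying on the usual Brevitas pattern of routing prefix-stripped keyword arguments (here via filter_kwargs('softmax_output_')) into the corresponding QuantIdentity.

import torch

from brevitas.nn.quant_sdpa import QuantScaledDotProductAttention
from brevitas.nn.quant_sdpa import ScaledDotProductAttention

# Shapes follow the docstring legend: N batch, H heads, L target length,
# S source length, E embedding dim (here Ev == E).
N, H, L, S, E = 2, 8, 16, 16, 64
query = torch.randn(N, H, L, E)
key = torch.randn(N, H, S, E)
value = torch.randn(N, H, S, E)

# Unquantized wrapper: forwards straight to F.scaled_dot_product_attention.
sdpa = ScaledDotProductAttention()
out = sdpa(query, key, value, is_causal=True)

# Quantized variant with the default quantizers from the __init__ signature.
# 'softmax_output_bit_width' is an assumed override of the softmax output
# quantizer; drop it to keep the defaults.
quant_sdpa = QuantScaledDotProductAttention(softmax_output_bit_width=8)
quant_out = quant_sdpa(query, key, value, is_causal=True)

print(out.shape, quant_out.shape)  # both torch.Size([2, 8, 16, 64])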